From 9388b74dd017f166e064555c683c30acac62bf7d Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Sun, 9 Feb 2025 15:51:52 +0100 Subject: [PATCH 01/17] KEP-2170: Use SSA to reconcile TrainJob components Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 26 +- pkg/runtime/core/clustertrainingruntime.go | 3 +- pkg/runtime/core/trainingruntime.go | 12 +- pkg/runtime/framework/core/framework.go | 18 +- pkg/runtime/framework/core/framework_test.go | 2 +- pkg/runtime/framework/interface.go | 3 +- .../plugins/coscheduling/coscheduling.go | 60 ++--- .../framework/plugins/jobset/builder.go | 240 ++++++++++-------- .../framework/plugins/jobset/jobset.go | 80 +++--- pkg/runtime/framework/plugins/mpi/mpi.go | 4 +- pkg/runtime/interface.go | 3 +- 11 files changed, 220 insertions(+), 231 deletions(-) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index addb7675eb..cb3438b19f 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -114,32 +114,18 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil { return buildFailed, err } + logKeysAndValues := []any{ "groupVersionKind", gvk.String(), "namespace", obj.GetNamespace(), "name", obj.GetName(), } - // TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence. - // Non-empty resourceVersion indicates UPDATE operation. - var creationErr error - var created bool - if obj.GetResourceVersion() == "" { - creationErr = r.client.Create(ctx, obj) - created = creationErr == nil - } - switch { - case created: - log.V(5).Info("Succeeded to create object", logKeysAndValues...) 
- continue - case client.IgnoreAlreadyExists(creationErr) != nil: - return creationFailed, creationErr - default: - // This indicates CREATE operation has not been performed or the object has already existed in the cluster. - if err = r.client.Update(ctx, obj); err != nil { - return updateFailed, err - } - log.V(5).Info("Succeeded to update object", logKeysAndValues...) + + if err := r.client.Patch(ctx, obj, client.Apply, client.FieldOwner("trainer"), client.ForceOwnership); err != nil { + return buildFailed, err } + + log.V(5).Info("Succeeded to update object", logKeysAndValues...) } return creationSucceeded, nil } diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 6fe0be1501..8bea04c382 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -22,6 +22,7 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -52,7 +53,7 @@ func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndex }, nil } -func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) { +func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { var clTrainingRuntime trainer.ClusterTrainingRuntime if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name}, &clTrainingRuntime); err != nil { return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err) diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index 7736a52b24..157e742455 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -22,11 +22,11 @@ import ( "fmt" metav1 
"k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" - jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/runtime" @@ -71,7 +71,7 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie return trainingRuntimeFactory, nil } -func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) { +func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { var trainingRuntime trainer.TrainingRuntime err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime) if err != nil { @@ -82,7 +82,7 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai func (r *TrainingRuntime) buildObjects( ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy, -) ([]client.Object, error) { +) ([]*unstructured.Unstructured, error) { propagationLabels := jobSetTemplateSpec.Labels if propagationLabels == nil && trainJob.Spec.Labels != nil { propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) @@ -121,11 +121,7 @@ func (r *TrainingRuntime) buildObjects( return nil, err } - jobSetTemplate := jobsetv1alpha2.JobSet{ - Spec: jobSetTemplateSpec.Spec, - } - - return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) + return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob) } func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) 
(*metav1.Condition, error) { diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 80e3bb4b60..8c334729ca 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -21,6 +21,8 @@ import ( "errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + k8sruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -112,15 +114,17 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) return aggregatedWarnings, aggregatedErrors } -func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) { - var objs []client.Object +func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { + var objs []*unstructured.Unstructured for _, plugin := range f.componentBuilderPlugins { - obj, err := plugin.Build(ctx, runtimeJobTemplate, info, trainJob) - if err != nil { + if component, err := plugin.Build(ctx, info, trainJob); err != nil { return nil, err - } - if obj != nil { - objs = append(objs, obj...) 
+ } else if component != nil { + if content, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(component); err != nil { + return nil, err + } else { + objs = append(objs, &unstructured.Unstructured{Object: content}) + } } } return objs, nil diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 0712404d62..3cc9d90440 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -511,7 +511,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { if err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo, tc.trainJob); err != nil { t.Fatal(err) } - objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeJobTemplate, tc.runtimeInfo, tc.trainJob) + objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob) if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected errors (-want,+got):\n%s", diff) } diff --git a/pkg/runtime/framework/interface.go b/pkg/runtime/framework/interface.go index 9f208dd174..17b05f07d2 100644 --- a/pkg/runtime/framework/interface.go +++ b/pkg/runtime/framework/interface.go @@ -21,7 +21,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" - "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -54,7 +53,7 @@ type EnforceMLPolicyPlugin interface { type ComponentBuilderPlugin interface { Plugin - Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) + Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) } type TerminalConditionPlugin interface { diff --git a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go index 
11f8aaca31..366fb30893 100644 --- a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go @@ -20,18 +20,15 @@ import ( "context" "errors" "fmt" - "maps" "slices" "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" nodev1 "k8s.io/api/node/v1" - "k8s.io/apimachinery/pkg/api/equality" - apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" + metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" "k8s.io/utils/ptr" @@ -39,15 +36,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" - ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" + schedulerpluginsv1alpha1ac "sigs.k8s.io/scheduler-plugins/pkg/generated/applyconfiguration/scheduling/v1alpha1" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" - "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" runtimeindexer "github.com/kubeflow/trainer/pkg/runtime/indexer" @@ -105,7 +101,7 @@ func (c *CoScheduling) EnforcePodGroupPolicy(info *runtime.Info, trainJob *train return nil } -func (c *CoScheduling) Build(ctx context.Context, _ client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) { +func (c *CoScheduling) Build(_ context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { if info == nil || 
info.RuntimePolicy.PodGroupPolicy == nil || info.RuntimePolicy.PodGroupPolicy.Coscheduling == nil || trainJob == nil { return nil, nil } @@ -121,41 +117,23 @@ func (c *CoScheduling) Build(ctx context.Context, _ client.Object, info *runtime totalResources[resName] = current } } - newPG := &schedulerpluginsv1alpha1.PodGroup{ - TypeMeta: metav1.TypeMeta{ - APIVersion: schedulerpluginsv1alpha1.SchemeGroupVersion.String(), - Kind: constants.PodGroupKind, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: trainJob.Name, - Namespace: trainJob.Namespace, - }, - Spec: schedulerpluginsv1alpha1.PodGroupSpec{ - ScheduleTimeoutSeconds: info.RuntimePolicy.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds, - MinMember: totalMembers, - MinResources: totalResources, - }, - } - if err := ctrlutil.SetControllerReference(trainJob, newPG, c.scheme); err != nil { - return nil, err - } - oldPG := &schedulerpluginsv1alpha1.PodGroup{} - if err := c.client.Get(ctx, client.ObjectKeyFromObject(newPG), oldPG); err != nil { - if !apierrors.IsNotFound(err) { - return nil, err - } - oldPG = nil - } - if needsCreateOrUpdate(oldPG, newPG, ptr.Deref(trainJob.Spec.Suspend, false)) { - return []client.Object{newPG}, nil - } - return nil, nil -} -func needsCreateOrUpdate(old, new *schedulerpluginsv1alpha1.PodGroup, trainJobIsSuspended bool) bool { - return old == nil || - trainJobIsSuspended && - (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) + podGroup := schedulerpluginsv1alpha1ac.PodGroup(trainJob.Name, trainJob.Namespace) + + podGroup.WithSpec(schedulerpluginsv1alpha1ac.PodGroupSpec(). + WithMinMember(totalMembers). + WithMinResources(totalResources). + WithScheduleTimeoutSeconds(*info.RuntimePolicy.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds)) + + podGroup.WithOwnerReferences(metav1ac.OwnerReference(). + WithAPIVersion(trainer.GroupVersion.String()). + WithKind(trainer.TrainJobKind). 
+ WithName(trainJob.Name). + WithUID(trainJob.UID). + WithController(true). + WithBlockOwnerDeletion(true)) + + return []any{podGroup}, nil } type PodGroupRuntimeClassHandler struct { diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index aa7212d548..7dc9a8364d 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -17,13 +17,11 @@ limitations under the License. package jobset import ( - "maps" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" - "sigs.k8s.io/controller-runtime/pkg/client" - jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + "k8s.io/utils/ptr" + jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/constants" @@ -31,104 +29,63 @@ import ( ) type Builder struct { - jobsetv1alpha2.JobSet + *jobsetv1alpha2ac.JobSetApplyConfiguration } -func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec trainer.JobSetTemplateSpec) *Builder { +func NewBuilder(jobSet *jobsetv1alpha2ac.JobSetApplyConfiguration) *Builder { return &Builder{ - JobSet: jobsetv1alpha2.JobSet{ - TypeMeta: metav1.TypeMeta{ - APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(), - Kind: constants.JobSetKind, - }, - ObjectMeta: metav1.ObjectMeta{ - Namespace: objectKey.Namespace, - Name: objectKey.Name, - Labels: maps.Clone(jobSetTemplateSpec.Labels), - Annotations: maps.Clone(jobSetTemplateSpec.Annotations), - }, - Spec: *jobSetTemplateSpec.Spec.DeepCopy(), - }, - } -} - -// mergeInitializerEnvs merges the TrainJob and Runtime Pod envs. 
-func mergeInitializerEnvs(storageUri *string, trainJobEnvs, containerEnv []corev1.EnvVar) []corev1.EnvVar { - envNames := sets.New[string]() - var envs []corev1.EnvVar - // Add the Storage URI env. - if storageUri != nil { - envNames.Insert(InitializerEnvStorageUri) - envs = append(envs, corev1.EnvVar{ - Name: InitializerEnvStorageUri, - Value: *storageUri, - }) - } - // Add the rest TrainJob envs. - // TODO (andreyvelich): Validate that TrainJob dataset and model envs don't have the STORAGE_URI env. - for _, e := range trainJobEnvs { - envNames.Insert(e.Name) - envs = append(envs, e) - } - - // TrainJob envs take precedence over the TrainingRuntime envs. - for _, e := range containerEnv { - if !envNames.Has(e.Name) { - envs = append(envs, e) - } + JobSetApplyConfiguration: jobSet, } - return envs } // Initializer updates JobSet values for the initializer Job. func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder { for i, rJob := range b.Spec.ReplicatedJobs { - if rJob.Name == constants.JobInitializer { + if *rJob.Name == constants.JobInitializer { // TODO (andreyvelich): Currently, we use initContainers for the initializers. // Once JobSet supports execution policy for the ReplicatedJobs, we should migrate to containers. // Ref: https://github.com/kubernetes-sigs/jobset/issues/672 for j, container := range rJob.Template.Spec.Template.Spec.InitContainers { // Update values for the dataset initializer container. - if container.Name == constants.ContainerDatasetInitializer && trainJob.Spec.DatasetConfig != nil { + if *container.Name == constants.ContainerDatasetInitializer && trainJob.Spec.DatasetConfig != nil { + env := &b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env // Update the dataset initializer envs. 
- b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env = mergeInitializerEnvs( - trainJob.Spec.DatasetConfig.StorageUri, - trainJob.Spec.DatasetConfig.Env, - container.Env, - ) + if storageUri := trainJob.Spec.DatasetConfig.StorageUri; storageUri != nil { + upsertEnvVars(env, corev1.EnvVar{ + Name: InitializerEnvStorageUri, + Value: *storageUri, + }) + } + upsertEnvVars(env, trainJob.Spec.DatasetConfig.Env...) // Update the dataset initializer secret reference. if trainJob.Spec.DatasetConfig.SecretRef != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].EnvFrom = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].EnvFrom, - corev1.EnvFromSource{ - SecretRef: &corev1.SecretEnvSource{ - LocalObjectReference: *trainJob.Spec.DatasetConfig.SecretRef, - }, - }, - ) + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j]. + WithEnvFrom(corev1ac.EnvFromSource(). + WithSecretRef(corev1ac.SecretEnvSource(). + WithName(trainJob.Spec.DatasetConfig.SecretRef.Name))) } } // TODO (andreyvelich): Add the model exporter when we support it. // Update values for the model initializer container. - if container.Name == constants.ContainerModelInitializer && trainJob.Spec.ModelConfig != nil && trainJob.Spec.ModelConfig.Input != nil { + if *container.Name == constants.ContainerModelInitializer && + trainJob.Spec.ModelConfig != nil && + trainJob.Spec.ModelConfig.Input != nil { // Update the model initializer envs. 
- b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env = mergeInitializerEnvs( - trainJob.Spec.ModelConfig.Input.StorageUri, - trainJob.Spec.ModelConfig.Input.Env, - container.Env, - ) + env := &b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env + if storageUri := trainJob.Spec.ModelConfig.Input.StorageUri; storageUri != nil { + upsertEnvVars(env, corev1.EnvVar{ + Name: InitializerEnvStorageUri, + Value: *storageUri, + }) + } + upsertEnvVars(env, trainJob.Spec.ModelConfig.Input.Env...) // Update the model initializer secret reference. if trainJob.Spec.ModelConfig.Input.SecretRef != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].EnvFrom = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].EnvFrom, - corev1.EnvFromSource{ - SecretRef: &corev1.SecretEnvSource{ - LocalObjectReference: *trainJob.Spec.ModelConfig.Input.SecretRef, - }, - }, - ) + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j]. + WithEnvFrom(corev1ac.EnvFromSource(). + WithSecretRef(corev1ac.SecretEnvSource(). + WithName(trainJob.Spec.ModelConfig.Input.SecretRef.Name))) } - } } } @@ -184,7 +141,7 @@ func (b *Builder) Launcher(info *runtime.Info, trainJob *trainer.TrainJob) *Buil // Trainer updates JobSet values for the trainer Job. func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Builder { for i, rJob := range b.Spec.ReplicatedJobs { - if rJob.Name == constants.JobTrainerNode { + if *rJob.Name == constants.JobTrainerNode { // Update the Parallelism and Completions values for the Trainer Job. b.Spec.ReplicatedJobs[i].Template.Spec.Parallelism = info.Trainer.NumNodes b.Spec.ReplicatedJobs[i].Template.Spec.Completions = info.Trainer.NumNodes @@ -195,42 +152,33 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build // Update values for the Trainer container. 
for j, container := range rJob.Template.Spec.Template.Spec.Containers { - if container.Name == constants.ContainerTrainer { + if *container.Name == constants.ContainerTrainer { // Update values from the TrainJob trainer. if trainJob.Spec.Trainer != nil { - if trainJob.Spec.Trainer.Image != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Image = *trainJob.Spec.Trainer.Image + if image := trainJob.Spec.Trainer.Image; image != nil { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Image = image } - if trainJob.Spec.Trainer.Command != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Command = trainJob.Spec.Trainer.Command + if command := trainJob.Spec.Trainer.Command; command != nil { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Command = command } - if trainJob.Spec.Trainer.Args != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Args = trainJob.Spec.Trainer.Args + if args := trainJob.Spec.Trainer.Args; args != nil { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Args = args } - if trainJob.Spec.Trainer.ResourcesPerNode != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Resources = *trainJob.Spec.Trainer.ResourcesPerNode + if resourcesPerNode := trainJob.Spec.Trainer.ResourcesPerNode; resourcesPerNode != nil { + b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j]. + WithResources(corev1ac.ResourceRequirements(). + WithRequests(resourcesPerNode.Requests). + WithLimits(resourcesPerNode.Limits)) } } // Update values from the Info object. - if info.Trainer.Env != nil { + if env := info.Trainer.Env; env != nil { // Update JobSet envs from the Info. - envNames := sets.New[string]() - for _, env := range info.Trainer.Env { - envNames.Insert(env.Name) - } - trainerEnvs := info.Trainer.Env - // Info envs take precedence over the TrainingRuntime envs. 
- for _, env := range container.Env { - if !envNames.Has(env.Name) { - trainerEnvs = append(trainerEnvs, env) - } - } - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env = trainerEnvs + upsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env...) } // Update the Trainer container port. - if info.Trainer.ContainerPort != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *info.Trainer.ContainerPort) + if port := info.Trainer.ContainerPort; port != nil { + upsertPorts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *port) } // Update the Trainer container volume mounts. if info.Trainer.VolumeMounts != nil { @@ -245,9 +193,10 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build } // TODO: Supporting merge labels would be great. + func (b *Builder) PodLabels(labels map[string]string) *Builder { for i := range b.Spec.ReplicatedJobs { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Labels = labels + b.Spec.ReplicatedJobs[i].Template.Spec.Template.WithLabels(labels) } return b } @@ -259,6 +208,83 @@ func (b *Builder) Suspend(suspend *bool) *Builder { // TODO: Need to support all TrainJob fields. 
-func (b *Builder) Build() *jobsetv1alpha2.JobSet { - return &b.JobSet +func (b *Builder) Build() *jobsetv1alpha2ac.JobSetApplyConfiguration { + return b.JobSetApplyConfiguration +} + +func upsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars ...corev1.EnvVar) { + for _, e := range envVars { + envVar := corev1ac.EnvVar().WithName(e.Name) + if from := e.ValueFrom; from != nil { + source := corev1ac.EnvVarSource() + if ref := from.FieldRef; ref != nil { + source.WithFieldRef(corev1ac.ObjectFieldSelector().WithFieldPath(ref.FieldPath)) + } + if ref := from.ResourceFieldRef; ref != nil { + source.WithResourceFieldRef(corev1ac.ResourceFieldSelector(). + WithContainerName(ref.ContainerName). + WithResource(ref.Resource). + WithDivisor(ref.Divisor)) + } + if ref := from.ConfigMapKeyRef; ref != nil { + key := corev1ac.ConfigMapKeySelector().WithKey(ref.Key).WithName(ref.Name) + if optional := ref.Optional; optional != nil { + key.WithOptional(*optional) + } + source.WithConfigMapKeyRef(key) + } + if ref := from.SecretKeyRef; ref != nil { + key := corev1ac.SecretKeySelector().WithKey(ref.Key).WithName(ref.Name) + if optional := ref.Optional; optional != nil { + key.WithOptional(*optional) + } + source.WithSecretKeyRef(key) + } + envVar.WithValueFrom(source) + } else { + envVar.WithValue(e.Value) + } + upsert(envVarList, envVar, byEnvVarName) + } +} + +func upsertPorts(portList *[]corev1ac.ContainerPortApplyConfiguration, ports ...corev1.ContainerPort) { + for _, p := range ports { + port := corev1ac.ContainerPort() + if p.ContainerPort > 0 { + port.WithContainerPort(p.ContainerPort) + } + if p.HostPort > 0 { + port.WithHostPort(p.HostPort) + } + if p.HostIP != "" { + port.WithHostIP(p.HostIP) + } + if p.Name != "" { + port.WithName(p.Name) + } + if p.Protocol != "" { + port.WithProtocol(p.Protocol) + } + upsert(portList, port, byContainerPortOrName) + } +} + +func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { + return ptr.Equal(a.Name, 
b.Name) +} + +func byContainerPortOrName(a, b corev1ac.ContainerPortApplyConfiguration) bool { + return ptr.Equal(a.ContainerPort, b.ContainerPort) || ptr.Equal(a.Name, b.Name) +} + +type compare[T any] func(T, T) bool + +func upsert[T any](items *[]T, item *T, predicate compare[T]) { + for i, t := range *items { + if predicate(t, *item) { + (*items)[i] = *item; return + } + } + *items = append(*items, *item) } diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index e900fa4031..9f514d2989 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -18,23 +18,23 @@ package jobset import ( "context" + "encoding/json" "fmt" "maps" "github.com/go-logr/logr" - "k8s.io/apimachinery/pkg/api/equality" - apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/utils/ptr" + metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" - ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/constants" @@ -85,38 +85,43 @@ func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder { } } -func (j *JobSet) Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) { - if runtimeJobTemplate == nil || info == nil || trainJob == nil { +func (j *JobSet) Build(ctx 
context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { + if info == nil || trainJob == nil { return nil, fmt.Errorf("runtime info or object is missing") } - raw, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet) - if !ok { - return nil, nil + // Get the runtime as unstructured from the TrainJob ref + runtimeJobTemplate := &unstructured.Unstructured{} + runtimeJobTemplate.SetAPIVersion(trainer.GroupVersion.String()) + runtimeJobTemplate.SetKind(*trainJob.Spec.RuntimeRef.Kind) + err := j.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, runtimeJobTemplate) + if err != nil { + return nil, err } - var jobSetBuilder *Builder - oldJobSet := &jobsetv1alpha2.JobSet{} - if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil { - if !apierrors.IsNotFound(err) { - return nil, err - } - jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), trainer.JobSetTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: info.Labels, - Annotations: info.Annotations, - }, - Spec: raw.Spec, - }) - oldJobSet = nil + // Populate the JobSet template spec apply configuration + jobSetTemplateSpec := &jobsetv1alpha2ac.JobSetSpecApplyConfiguration{} + if jobSetSpec, ok, err := unstructured.NestedFieldCopy(runtimeJobTemplate.Object, "spec", "template", "spec"); err != nil { + return nil, err + } else if !ok { + return nil, fmt.Errorf("trainJob runtime %s does not have a spec.template.spec field", trainJob.Spec.RuntimeRef.Name) } else { - jobSetBuilder = &Builder{ - JobSet: *oldJobSet.DeepCopy(), + if raw, err := json.Marshal(jobSetSpec); err != nil { + return nil, err + } else if err := json.Unmarshal(raw, jobSetTemplateSpec); err != nil { + return nil, err } } + // Init the JobSet apply configuration from the runtime template spec + jobSetBuilder := NewBuilder(jobsetv1alpha2ac.JobSet(trainJob.Name, trainJob.Namespace). + WithLabels(maps.Clone(info.Labels)). 
+ WithAnnotations(maps.Clone(info.Annotations)). + WithSpec(jobSetTemplateSpec)) + // TODO (andreyvelich): Add support for the PodSpecOverride. // TODO (andreyvelich): Refactor the builder with wrappers for PodSpec. + // Apply the runtime info jobSet := jobSetBuilder. Initializer(trainJob). Launcher(info, trainJob). @@ -124,24 +129,17 @@ func (j *JobSet) Build(ctx context.Context, runtimeJobTemplate client.Object, in PodLabels(info.PodLabels). Suspend(trainJob.Spec.Suspend). Build() - if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil { - return nil, err - } - if needsCreateOrUpdate(oldJobSet, jobSet, ptr.Deref(trainJob.Spec.Suspend, false)) { - return []client.Object{jobSet}, nil - } - return nil, nil -} - -func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool { - return old == nil || - (!trainJobIsSuspended && jobSetIsSuspended(old) && !jobSetIsSuspended(new)) || - (trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))) -} + // Set the TrainJob as owner + jobSet.WithOwnerReferences(metav1ac.OwnerReference(). + WithAPIVersion(trainer.GroupVersion.String()). + WithKind(trainer.TrainJobKind). + WithName(trainJob.Name). + WithUID(trainJob.UID). + WithController(true). 
+ WithBlockOwnerDeletion(true)) -func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool { - return ptr.Deref(jobSet.Spec.Suspend, false) + return []any{jobSet}, nil } func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index d26c58c7c4..03dcbd0c1a 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -209,7 +209,7 @@ func (m *MPI) ReconcilerBuilders() []runtime.ReconcilerBuilder { } } -func (m *MPI) Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) { +func (m *MPI) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.MPI == nil { return nil, nil } @@ -224,7 +224,7 @@ func (m *MPI) Build(ctx context.Context, runtimeJobTemplate client.Object, info return nil, fmt.Errorf("failed to build ConfigMap with hostfile. 
Error: %v", err) } - return []client.Object{secret, configMap}, nil + return []any{secret, configMap}, nil } func (m *MPI) buildSSHAuthSecret(ctx context.Context, trainJob *trainer.TrainJob) (*corev1.Secret, error) { diff --git a/pkg/runtime/interface.go b/pkg/runtime/interface.go index ff93e1b010..3b045ee0a0 100644 --- a/pkg/runtime/interface.go +++ b/pkg/runtime/interface.go @@ -20,6 +20,7 @@ import ( "context" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -32,7 +33,7 @@ import ( type ReconcilerBuilder func(*builder.Builder, client.Client, cache.Cache) *builder.Builder type Runtime interface { - NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) + NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) From 84ee19c26c30edf40eadbdd88dd23e67ea1e386c Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 11 Feb 2025 14:34:34 +0100 Subject: [PATCH 02/17] Enable Unstructured caching in controller manager config Signed-off-by: Antonin Stefanutti --- cmd/trainer-controller-manager/main.go | 6 + go.mod | 18 +-- go.sum | 40 ++--- ....kubeflow.org_clustertrainingruntimes.yaml | 149 +++++++++++++++++- ...trainer.kubeflow.org_trainingruntimes.yaml | 149 +++++++++++++++++- 5 files changed, 325 insertions(+), 37 deletions(-) diff --git a/cmd/trainer-controller-manager/main.go b/cmd/trainer-controller-manager/main.go index 28bb04b26b..17620af671 100644 --- a/cmd/trainer-controller-manager/main.go +++ b/cmd/trainer-controller-manager/main.go @@ -29,6 +29,7 @@ import ( 
utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" ctrlpkg "sigs.k8s.io/controller-runtime/pkg/controller" "sigs.k8s.io/controller-runtime/pkg/healthz" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -111,6 +112,11 @@ func main() { } mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{ Scheme: scheme, + Client: client.Options{ + Cache: &client.CacheOptions{ + Unstructured: true, + }, + }, Metrics: metricsserver.Options{ BindAddress: metricsAddr, SecureServing: secureMetrics, diff --git a/go.mod b/go.mod index 8dbb67fc9a..eba738cae4 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.23.0 require ( github.com/go-logr/logr v1.4.2 github.com/google/go-cmp v0.6.0 - github.com/onsi/ginkgo/v2 v2.22.0 - github.com/onsi/gomega v1.36.1 + github.com/onsi/ginkgo/v2 v2.22.2 + github.com/onsi/gomega v1.36.2 github.com/open-policy-agent/cert-controller v0.12.0 go.uber.org/zap v1.27.0 golang.org/x/crypto v0.31.0 @@ -18,10 +18,10 @@ require ( k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 sigs.k8s.io/controller-runtime v0.20.2 - sigs.k8s.io/jobset v0.5.2 + sigs.k8s.io/jobset v0.8.0-devel.0.20250212132206-c69f95cd53b4 sigs.k8s.io/kueue v0.6.3 sigs.k8s.io/scheduler-plugins v0.30.6 - sigs.k8s.io/structured-merge-diff/v4 v4.4.2 + sigs.k8s.io/structured-merge-diff/v4 v4.5.0 ) require ( @@ -42,7 +42,7 @@ require ( github.com/google/btree v1.1.3 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db // indirect + github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect github.com/google/uuid v1.6.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -52,7 +52,7 @@ require ( github.com/modern-go/reflect2 
v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.20.2 // indirect + github.com/prometheus/client_golang v1.20.5 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect @@ -60,7 +60,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/mod v0.21.0 // indirect + golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.10.0 // indirect @@ -68,9 +68,9 @@ require ( golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.7.0 // indirect - golang.org/x/tools v0.26.0 // indirect + golang.org/x/tools v0.28.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/protobuf v1.35.1 // indirect + google.golang.org/protobuf v1.36.1 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 7a5c73f491..722ac0f2d6 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad 
h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -69,10 +69,10 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/onsi/ginkgo/v2 v2.22.0 h1:Yed107/8DjTr0lKCNt7Dn8yQ6ybuDRQoMGrNFKzMfHg= -github.com/onsi/ginkgo/v2 v2.22.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= -github.com/onsi/gomega v1.36.1 h1:bJDPBO7ibjxcbHMgSCoo4Yj18UWbKDlLwX1x9sybDcw= -github.com/onsi/gomega v1.36.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog= +github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU= +github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/open-policy-agent/cert-controller v0.12.0 h1:RKXlBafMcCh+++I1geJetXo77tAjyb4542DQc/+aZIw= github.com/open-policy-agent/cert-controller v0.12.0/go.mod h1:N5bCFXdAXMYx0PdS6ZQ9lrDQQMz+F6deoChym6VleXw= github.com/open-policy-agent/frameworks/constraint v0.0.0-20241101234656-e78c8abd754a h1:gQtOJ50XFyL2Xh3lDD9zP4KQ2PY4mZKQ9hDcWc81Sp8= @@ -82,8 +82,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg= -github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= +github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= @@ -96,8 +96,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -117,8 +117,8 @@ golang.org/x/crypto v0.31.0 
h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -149,16 +149,16 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= -golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= +golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -190,15 +190,15 @@ k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6J k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/controller-runtime v0.20.2 h1:/439OZVxoEc02psi1h4QO3bHzTgu49bb347Xp4gW1pc= sigs.k8s.io/controller-runtime v0.20.2/go.mod h1:xg2XB0K5ShQzAgsoujxuKN4LNXR2LfwwHsPj7Iaw+XY= -sigs.k8s.io/jobset v0.5.2 h1:276q5Pi/ErLYj+GQ0ydEXR6tx3LwBhEzHLQv+k8bYF4= -sigs.k8s.io/jobset v0.5.2/go.mod h1:Vg99rj/6OoGvy1uvywGEHOcVLCWWJYkJtisKqdWzcFw= +sigs.k8s.io/jobset v0.8.0-devel.0.20250212132206-c69f95cd53b4 h1:f4fx7+T4Bp6v+nFs5bCPq/py+Xt6DYEHbWhF/CRkAUQ= +sigs.k8s.io/jobset v0.8.0-devel.0.20250212132206-c69f95cd53b4/go.mod h1:egRLNm7qi4s1cj+sPvleUagDF5icYb7UH4FwGlni6+Q= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= sigs.k8s.io/kueue 
v0.6.3 h1:PmccdKPDFQIaAboyuSG6M0w6hXtxVA51RV+DjCUtBtQ= sigs.k8s.io/kueue v0.6.3/go.mod h1:rliYfK/K7pJ7CT4ReV1szzciNkAo3sBn5Bmr5Sn6uCY= sigs.k8s.io/scheduler-plugins v0.30.6 h1:P4pViMVoyVNHWmkG96UtJ4LvxkUIeenIUKLZd09vDyw= sigs.k8s.io/scheduler-plugins v0.30.6/go.mod h1:EDYYqHmpHR//VYKAeud1TTQbTFSvpdGFeyEg9ejOmnI= -sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA= -sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4= +sigs.k8s.io/structured-merge-diff/v4 v4.5.0 h1:nbCitCK2hfnhyiKo6uf2HxUPTCodY6Qaf85SbDIaMBk= +sigs.k8s.io/structured-merge-diff/v4 v4.5.0/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= diff --git a/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml b/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml index 370797c8f1..0f06bc3213 100644 --- a/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml +++ b/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml @@ -652,6 +652,31 @@ spec: description: Specification of the desired JobSet which will be created from TrainJob. properties: + coordinator: + description: |- + Coordinator can be used to assign a specific pod as the coordinator for + the JobSet. If defined, an annotation will be added to all Jobs and pods with + coordinator pod, which contains the stable network endpoint where the + coordinator pod can be reached. + jobset.sigs.k8s.io/coordinator=. + properties: + jobIndex: + description: |- + JobIndex is the index of Job which contains the coordinator pod + (i.e., for a ReplicatedJob with N replicas, there are Job indexes 0 to N-1). + type: integer + podIndex: + description: PodIndex is the Job completion index of the + coordinator pod. 
+ type: integer + replicatedJob: + description: |- + ReplicatedJob is the name of the ReplicatedJob which contains + the coordinator pod. + type: string + required: + - replicatedJob + type: object failurePolicy: description: |- FailurePolicy, if set, configures when to declare the JobSet as @@ -665,13 +690,79 @@ spec: A restart is achieved by recreating all active child jobs. format: int32 type: integer + restartStrategy: + default: Recreate + description: |- + RestartStrategy defines the strategy to use when restarting the JobSet. + Defaults to Recreate. + enum: + - Recreate + - BlockingRecreate + type: string + rules: + description: |- + List of failure policy rules for this JobSet. + For a given Job failure, the rules will be evaluated in order, + and only the first matching rule will be executed. + If no matching rule is found, the RestartJobSet action is applied. + items: + description: |- + FailurePolicyRule defines a FailurePolicyAction to be executed if a child job + fails due to a reason listed in OnJobFailureReasons. + properties: + action: + description: The action to take if the rule is matched. + enum: + - FailJobSet + - RestartJobSet + - RestartJobSetAndIgnoreMaxRestarts + type: string + name: + description: |- + The name of the failure policy rule. + The name is defaulted to 'failurePolicyRuleN' where N is the index of the failure policy rule. + The name must match the regular expression "^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$". + type: string + onJobFailureReasons: + description: |- + The requirement on the job failure reasons. The requirement + is satisfied if at least one reason matches the list. + The rules are evaluated in order, and the first matching + rule is executed. + An empty list applies the rule to any job failure reason. + items: + type: string + type: array + targetReplicatedJobs: + description: |- + TargetReplicatedJobs are the names of the replicated jobs the operator applies to. + An empty list will apply to all replicatedJobs. 
+ items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - action + - name + type: object + type: array type: object x-kubernetes-validations: - message: Value is immutable rule: self == oldSelf managedBy: - description: ManagedBy is used to indicate the controller - or entity that manages a JobSet + description: |- + ManagedBy is used to indicate the controller or entity that manages a JobSet. + The built-in JobSet controller reconciles JobSets which don't have this + field at all or the field value is the reserved string + `jobset.sigs.k8s.io/jobset-controller`, but skips reconciling JobSets + with a custom value for this field. + + The value must be a valid domain-prefixed path (e.g. acme.io/foo) - + all characters before the first "/" must be a valid subdomain as defined + by RFC 1123. All characters trailing the first "/" must be valid HTTP Path + characters as defined by RFC 3986. The value cannot exceed 63 characters. + The field is immutable. type: string network: description: Network defines the networking options for the @@ -683,6 +774,11 @@ spec: Pods will be reachable using the fully qualified pod hostname: ---. type: boolean + publishNotReadyAddresses: + description: |- + Indicates if DNS records of pods should be published before the pods are ready. + Defaults to True. + type: boolean subdomain: description: |- Subdomain is an explicit choice for a network subdomain name @@ -698,6 +794,44 @@ spec: form the set. items: properties: + dependsOn: + description: |- + DependsOn is an optional list that specifies the preceding ReplicatedJobs upon which + the current ReplicatedJob depends. If specified, the ReplicatedJob will be created + only after the referenced ReplicatedJobs reach their desired state. + The Order of ReplicatedJobs is defined by their enumeration in the slice. + Note, that the first ReplicatedJob in the slice cannot use the DependsOn API. + Currently, only a single item is supported in the DependsOn list. 
+ If JobSet is suspended the all active ReplicatedJobs will be suspended. When JobSet is + resumed the Job sequence starts again. + This API is mutually exclusive with the StartupPolicy API. + items: + description: DependsOn defines the dependency on the + previous ReplicatedJob status. + properties: + name: + description: Name of the previous ReplicatedJob. + type: string + status: + description: Status defines the condition for + the ReplicatedJob. Only Ready or Complete status + can be set. + enum: + - Ready + - Complete + type: string + required: + - name + - status + type: object + maxItems: 1 + type: array + x-kubernetes-list-map-keys: + - name + x-kubernetes-list-type: map + x-kubernetes-validations: + - message: Value is immutable + rule: self == oldSelf name: description: |- Name is the name of the entry and will be used as a suffix @@ -9799,8 +9933,9 @@ spec: - name x-kubernetes-list-type: map startupPolicy: - description: StartupPolicy, if set, configures in what order - jobs must be started + description: |- + StartupPolicy, if set, configures in what order jobs must be started + Deprecated: StartupPolicy is deprecated, please use the DependsOn API. 
properties: startupPolicyOrder: description: |- @@ -9864,6 +9999,12 @@ spec: minimum: 0 type: integer type: object + x-kubernetes-validations: + - message: StartupPolicy and DependsOn APIs are mutually exclusive + rule: '!(has(self.startupPolicy) && self.startupPolicy.startupPolicyOrder + == ''InOrder'' && self.replicatedJobs.exists(x, has(x.dependsOn)))' + - message: DependsOn can't be set for the first ReplicatedJob + rule: '!(has(self.replicatedJobs[0].dependsOn))' type: object required: - template diff --git a/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml b/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml index 21febf692d..1f146218c2 100644 --- a/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml +++ b/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml @@ -652,6 +652,31 @@ spec: description: Specification of the desired JobSet which will be created from TrainJob. properties: + coordinator: + description: |- + Coordinator can be used to assign a specific pod as the coordinator for + the JobSet. If defined, an annotation will be added to all Jobs and pods with + coordinator pod, which contains the stable network endpoint where the + coordinator pod can be reached. + jobset.sigs.k8s.io/coordinator=. + properties: + jobIndex: + description: |- + JobIndex is the index of Job which contains the coordinator pod + (i.e., for a ReplicatedJob with N replicas, there are Job indexes 0 to N-1). + type: integer + podIndex: + description: PodIndex is the Job completion index of the + coordinator pod. + type: integer + replicatedJob: + description: |- + ReplicatedJob is the name of the ReplicatedJob which contains + the coordinator pod. + type: string + required: + - replicatedJob + type: object failurePolicy: description: |- FailurePolicy, if set, configures when to declare the JobSet as @@ -665,13 +690,79 @@ spec: A restart is achieved by recreating all active child jobs. 
format: int32 type: integer + restartStrategy: + default: Recreate + description: |- + RestartStrategy defines the strategy to use when restarting the JobSet. + Defaults to Recreate. + enum: + - Recreate + - BlockingRecreate + type: string + rules: + description: |- + List of failure policy rules for this JobSet. + For a given Job failure, the rules will be evaluated in order, + and only the first matching rule will be executed. + If no matching rule is found, the RestartJobSet action is applied. + items: + description: |- + FailurePolicyRule defines a FailurePolicyAction to be executed if a child job + fails due to a reason listed in OnJobFailureReasons. + properties: + action: + description: The action to take if the rule is matched. + enum: + - FailJobSet + - RestartJobSet + - RestartJobSetAndIgnoreMaxRestarts + type: string + name: + description: |- + The name of the failure policy rule. + The name is defaulted to 'failurePolicyRuleN' where N is the index of the failure policy rule. + The name must match the regular expression "^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$". + type: string + onJobFailureReasons: + description: |- + The requirement on the job failure reasons. The requirement + is satisfied if at least one reason matches the list. + The rules are evaluated in order, and the first matching + rule is executed. + An empty list applies the rule to any job failure reason. + items: + type: string + type: array + targetReplicatedJobs: + description: |- + TargetReplicatedJobs are the names of the replicated jobs the operator applies to. + An empty list will apply to all replicatedJobs. 
+ items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - action + - name + type: object + type: array type: object x-kubernetes-validations: - message: Value is immutable rule: self == oldSelf managedBy: - description: ManagedBy is used to indicate the controller - or entity that manages a JobSet + description: |- + ManagedBy is used to indicate the controller or entity that manages a JobSet. + The built-in JobSet controller reconciles JobSets which don't have this + field at all or the field value is the reserved string + `jobset.sigs.k8s.io/jobset-controller`, but skips reconciling JobSets + with a custom value for this field. + + The value must be a valid domain-prefixed path (e.g. acme.io/foo) - + all characters before the first "/" must be a valid subdomain as defined + by RFC 1123. All characters trailing the first "/" must be valid HTTP Path + characters as defined by RFC 3986. The value cannot exceed 63 characters. + The field is immutable. type: string network: description: Network defines the networking options for the @@ -683,6 +774,11 @@ spec: Pods will be reachable using the fully qualified pod hostname: ---. type: boolean + publishNotReadyAddresses: + description: |- + Indicates if DNS records of pods should be published before the pods are ready. + Defaults to True. + type: boolean subdomain: description: |- Subdomain is an explicit choice for a network subdomain name @@ -698,6 +794,44 @@ spec: form the set. items: properties: + dependsOn: + description: |- + DependsOn is an optional list that specifies the preceding ReplicatedJobs upon which + the current ReplicatedJob depends. If specified, the ReplicatedJob will be created + only after the referenced ReplicatedJobs reach their desired state. + The Order of ReplicatedJobs is defined by their enumeration in the slice. + Note, that the first ReplicatedJob in the slice cannot use the DependsOn API. + Currently, only a single item is supported in the DependsOn list. 
+ If JobSet is suspended the all active ReplicatedJobs will be suspended. When JobSet is + resumed the Job sequence starts again. + This API is mutually exclusive with the StartupPolicy API. + items: + description: DependsOn defines the dependency on the + previous ReplicatedJob status. + properties: + name: + description: Name of the previous ReplicatedJob. + type: string + status: + description: Status defines the condition for + the ReplicatedJob. Only Ready or Complete status + can be set. + enum: + - Ready + - Complete + type: string + required: + - name + - status + type: object + maxItems: 1 + type: array + x-kubernetes-list-map-keys: + - name + x-kubernetes-list-type: map + x-kubernetes-validations: + - message: Value is immutable + rule: self == oldSelf name: description: |- Name is the name of the entry and will be used as a suffix @@ -9799,8 +9933,9 @@ spec: - name x-kubernetes-list-type: map startupPolicy: - description: StartupPolicy, if set, configures in what order - jobs must be started + description: |- + StartupPolicy, if set, configures in what order jobs must be started + Deprecated: StartupPolicy is deprecated, please use the DependsOn API. 
properties: startupPolicyOrder: description: |- @@ -9864,6 +9999,12 @@ spec: minimum: 0 type: integer type: object + x-kubernetes-validations: + - message: StartupPolicy and DependsOn APIs are mutually exclusive + rule: '!(has(self.startupPolicy) && self.startupPolicy.startupPolicyOrder + == ''InOrder'' && self.replicatedJobs.exists(x, has(x.dependsOn)))' + - message: DependsOn can't be set for the first ReplicatedJob + rule: '!(has(self.replicatedJobs[0].dependsOn))' type: object required: - template From 3feed39982eb1e6a1a368785d3f2cf2cc67ebcd5 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 08:34:30 +0100 Subject: [PATCH 03/17] Fix PodGroup apply configuration Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/coscheduling/coscheduling.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go index 366fb30893..ada4ac70e6 100644 --- a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go @@ -121,7 +121,7 @@ func (c *CoScheduling) Build(_ context.Context, info *runtime.Info, trainJob *tr podGroup := schedulerpluginsv1alpha1ac.PodGroup(trainJob.Name, trainJob.Namespace) podGroup.WithSpec(schedulerpluginsv1alpha1ac.PodGroupSpec(). - WithScheduleTimeoutSeconds(totalMembers). + WithMinMember(totalMembers). WithMinResources(totalResources). 
WithScheduleTimeoutSeconds(*info.RuntimePolicy.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds)) From 1d63af11a957db3f67d04233d89de72fce67789d Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 08:36:56 +0100 Subject: [PATCH 04/17] API to apply config conversion util functions Signed-off-by: Antonin Stefanutti --- .../framework/plugins/jobset/builder.go | 51 +----- pkg/util/apply/apply.go | 153 ++++++++++++++++++ 2 files changed, 156 insertions(+), 48 deletions(-) create mode 100644 pkg/util/apply/apply.go diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index 7dc9a8364d..e829e8e5d9 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -26,6 +26,7 @@ import ( trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" + "github.com/kubeflow/trainer/pkg/util/apply" ) type Builder struct { @@ -214,59 +215,13 @@ func (b *Builder) Build() *jobsetv1alpha2ac.JobSetApplyConfiguration { func upsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars ...corev1.EnvVar) { for _, e := range envVars { - envVar := corev1ac.EnvVar().WithName(e.Name) - if from := e.ValueFrom; from != nil { - source := corev1ac.EnvVarSource() - if ref := from.FieldRef; ref != nil { - source.WithFieldRef(corev1ac.ObjectFieldSelector().WithFieldPath(ref.FieldPath)) - } - if ref := from.ResourceFieldRef; ref != nil { - source.WithResourceFieldRef(corev1ac.ResourceFieldSelector(). - WithContainerName(ref.ContainerName). - WithResource(ref.Resource). 
- WithDivisor(ref.Divisor)) - } - if ref := from.ConfigMapKeyRef; ref != nil { - key := corev1ac.ConfigMapKeySelector().WithKey(ref.Key).WithName(ref.Name) - if optional := ref.Optional; optional != nil { - key.WithOptional(*optional) - } - source.WithConfigMapKeyRef(key) - } - if ref := from.SecretKeyRef; ref != nil { - key := corev1ac.SecretKeySelector().WithKey(ref.Key).WithName(ref.Name) - if optional := ref.Optional; optional != nil { - key.WithOptional(*optional) - } - source.WithSecretKeyRef(key) - } - envVar.WithValueFrom(source) - } else { - envVar.WithValue(e.Value) - } - upsert(envVarList, envVar, byEnvVarName) + upsert(envVarList, apply.EnvVar(e), byEnvVarName) } } func upsertPorts(portList *[]corev1ac.ContainerPortApplyConfiguration, ports ...corev1.ContainerPort) { for _, p := range ports { - port := corev1ac.ContainerPort() - if p.ContainerPort > 0 { - port.WithContainerPort(p.ContainerPort) - } - if p.HostPort > 0 { - port.WithHostPort(p.HostPort) - } - if p.HostIP != "" { - port.WithHostIP(p.HostIP) - } - if p.Name != "" { - port.WithName(p.Name) - } - if p.Protocol != "" { - port.WithProtocol(p.Protocol) - } - upsert(portList, port, byContainerPortOrName) + upsert(portList, apply.ContainerPort(p), byContainerPortOrName) } } diff --git a/pkg/util/apply/apply.go b/pkg/util/apply/apply.go new file mode 100644 index 0000000000..e8ccd13c79 --- /dev/null +++ b/pkg/util/apply/apply.go @@ -0,0 +1,153 @@ +/* +Copyright 2025 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package apply + +import ( + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" +) + +func ContainerPort(p corev1.ContainerPort) *corev1ac.ContainerPortApplyConfiguration { + port := corev1ac.ContainerPort() + if p.ContainerPort > 0 { + port.WithContainerPort(p.ContainerPort) + } + if p.HostPort > 0 { + port.WithHostPort(p.HostPort) + } + if p.HostIP != "" { + port.WithHostIP(p.HostIP) + } + if p.Name != "" { + port.WithName(p.Name) + } + if p.Protocol != "" { + port.WithProtocol(p.Protocol) + } + return port +} + +func ContainerPorts(p ...corev1.ContainerPort) []*corev1ac.ContainerPortApplyConfiguration { + var ports []*corev1ac.ContainerPortApplyConfiguration + for _, port := range p { + ports = append(ports, ContainerPort(port)) + } + return ports +} + +func EnvVar(e corev1.EnvVar) *corev1ac.EnvVarApplyConfiguration { + envVar := corev1ac.EnvVar().WithName(e.Name) + if from := e.ValueFrom; from != nil { + source := corev1ac.EnvVarSource() + if ref := from.FieldRef; ref != nil { + source.WithFieldRef(corev1ac.ObjectFieldSelector().WithFieldPath(ref.FieldPath)) + } + if ref := from.ResourceFieldRef; ref != nil { + source.WithResourceFieldRef(corev1ac.ResourceFieldSelector(). + WithContainerName(ref.ContainerName). + WithResource(ref.Resource). 
+ WithDivisor(ref.Divisor)) + } + if ref := from.ConfigMapKeyRef; ref != nil { + key := corev1ac.ConfigMapKeySelector().WithKey(ref.Key).WithName(ref.Name) + if optional := ref.Optional; optional != nil { + key.WithOptional(*optional) + } + source.WithConfigMapKeyRef(key) + } + if ref := from.SecretKeyRef; ref != nil { + key := corev1ac.SecretKeySelector().WithKey(ref.Key).WithName(ref.Name) + if optional := ref.Optional; optional != nil { + key.WithOptional(*optional) + } + source.WithSecretKeyRef(key) + } + envVar.WithValueFrom(source) + } else { + envVar.WithValue(e.Value) + } + return envVar +} + +func EnvVars(e ...corev1.EnvVar) []*corev1ac.EnvVarApplyConfiguration { + var envs []*corev1ac.EnvVarApplyConfiguration + for _, env := range e { + envs = append(envs, EnvVar(env)) + } + return envs +} + +func EnvFromSource(e corev1.EnvFromSource) *corev1ac.EnvFromSourceApplyConfiguration { + envVarFrom := corev1ac.EnvFromSource() + if e.Prefix != "" { + envVarFrom.WithPrefix(e.Prefix) + } + if ref := e.ConfigMapRef; ref != nil { + source := corev1ac.ConfigMapEnvSource().WithName(ref.Name) + if ref.Optional != nil { + source.WithOptional(*ref.Optional) + } + envVarFrom.WithConfigMapRef(source) + } + if ref := e.SecretRef; ref != nil { + source := corev1ac.SecretEnvSource().WithName(ref.Name) + if ref.Optional != nil { + source.WithOptional(*ref.Optional) + } + envVarFrom.WithSecretRef(source) + } + return envVarFrom +} + +func EnvFromSources(e ...corev1.EnvFromSource) []*corev1ac.EnvFromSourceApplyConfiguration { + var envs []*corev1ac.EnvFromSourceApplyConfiguration + for _, env := range e { + envs = append(envs, EnvFromSource(env)) + } + return envs +} + +func Condition(c metav1.Condition) *metav1ac.ConditionApplyConfiguration { + condition := metav1ac.Condition(). 
+ WithObservedGeneration(c.ObservedGeneration) + if c.Type != "" { + condition.WithType(c.Type) + } + if c.Message != "" { + condition.WithMessage(c.Message) + } + if c.Reason != "" { + condition.WithReason(c.Reason) + } + if c.Status != "" { + condition.WithStatus(c.Status) + } + if !c.LastTransitionTime.IsZero() { + condition.WithLastTransitionTime(c.LastTransitionTime) + } + return condition +} + +func Conditions(c ...metav1.Condition) []*metav1ac.ConditionApplyConfiguration { + var conditions []*metav1ac.ConditionApplyConfiguration + for _, condition := range c { + conditions = append(conditions, Condition(condition)) + } + return conditions +} From a853dee87d752d09bcec7174d2af0a7ce29e7f19 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 16:23:02 +0100 Subject: [PATCH 05/17] Only add namespace to TrainingRuntime object key Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/jobset/jobset.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index 9f514d2989..0406b1c9f7 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -93,9 +93,19 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine // Get the runtime as unstructured from the TrainJob ref runtimeJobTemplate := &unstructured.Unstructured{} runtimeJobTemplate.SetAPIVersion(trainer.GroupVersion.String()) - runtimeJobTemplate.SetKind(*trainJob.Spec.RuntimeRef.Kind) - err := j.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, runtimeJobTemplate) - if err != nil { + + if kind := trainJob.Spec.RuntimeRef.Kind; kind != nil { + runtimeJobTemplate.SetKind(*kind) + } else { + runtimeJobTemplate.SetKind(trainer.ClusterTrainingRuntimeKind) + } + + key := client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name} + if 
runtimeJobTemplate.GetKind() == trainer.TrainingRuntimeKind { + key.Namespace = trainJob.Namespace + } + + if err := j.client.Get(ctx, key, runtimeJobTemplate); err != nil { return nil, err } From 66644427f479c8e9f7be243ce4a44b0148ae35ff Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 17:49:39 +0100 Subject: [PATCH 06/17] Fix EnvVar upsert Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/jobset/builder.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index e829e8e5d9..706cedbdfe 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -239,6 +239,7 @@ func upsert[T any](items *[]T, item *T, predicate compare[T]) { for i, t := range *items { if predicate(t, *item) { (*items)[i] = *item + return } } *items = append(*items, *item) From 9831a909af0c1150f56697a1dde14d026b5456f9 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 18:21:41 +0100 Subject: [PATCH 07/17] Update unit tests Signed-off-by: Antonin Stefanutti --- .../core/clustertrainingruntime_test.go | 22 ++++--- pkg/runtime/core/trainingruntime_test.go | 31 ++++++--- pkg/runtime/framework/core/framework_test.go | 47 +++++++++----- pkg/util/testing/unstructured.go | 63 +++++++++++++++++++ pkg/util/testing/wrapper.go | 3 - pkg/webhooks/trainingruntime_webhook_test.go | 24 ++++--- 6 files changed, 148 insertions(+), 42 deletions(-) create mode 100644 pkg/util/testing/unstructured.go diff --git a/pkg/runtime/core/clustertrainingruntime_test.go b/pkg/runtime/core/clustertrainingruntime_test.go index 88c1f3eada..658ae20f66 100644 --- a/pkg/runtime/core/clustertrainingruntime_test.go +++ b/pkg/runtime/core/clustertrainingruntime_test.go @@ -25,7 +25,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - 
"sigs.k8s.io/controller-runtime/pkg/client" + "k8s.io/apimachinery/pkg/runtime" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -41,7 +41,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { cases := map[string]struct { trainJob *trainer.TrainJob clusterTrainingRuntime *trainer.ClusterTrainingRuntime - wantObjs []client.Object + wantObjs []runtime.Object wantError error }{ "succeeded to build PodGroup and JobSet with NumNodes from the Runtime and container from the Trainer.": { @@ -63,7 +63,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). InitContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). NumNodes(100). @@ -95,7 +95,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }, } cmpOpts := []cmp.Option{ - cmpopts.SortSlices(func(a, b client.Object) bool { + cmpopts.SortSlices(func(a, b runtime.Object) bool { return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() }), cmpopts.EquateEmpty(), @@ -109,8 +109,9 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { if tc.clusterTrainingRuntime != nil { clientBuilder.WithObjects(tc.clusterTrainingRuntime) } + c := clientBuilder.Build() - trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder)) + trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) if err != nil { t.Fatal(err) } @@ -120,15 +121,22 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { t.Fatal("Failed type assertion from Runtime interface to TrainingRuntime") } - clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder)) + 
clTrainingRuntime, err := NewClusterTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) if err != nil { t.Fatal(err) } + objs, err := clTrainingRuntime.NewObjects(ctx, tc.trainJob) if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 { + + resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { t.Errorf("Unexpected objects (-want,+got):\n%s", diff) } }) diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index 0c04ff8b64..7fda722d9a 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -26,8 +26,8 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/intstr" - "sigs.k8s.io/controller-runtime/pkg/client" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -45,7 +45,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { cases := map[string]struct { trainingRuntime *trainer.TrainingRuntime trainJob *trainer.TrainJob - wantObjs []client.Object + wantObjs []runtime.Object wantError error }{ // Test cases for the PlainML MLPolicy. @@ -73,7 +73,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). InitContainerDatasetModelInitializer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). NumNodes(30). @@ -137,7 +137,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). 
Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). NumNodes(100). ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests). @@ -205,7 +205,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). NumNodes(100). ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). @@ -278,7 +278,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). NumNodes(30). ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). @@ -355,7 +355,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ). Obj(), - wantObjs: []client.Object{ + wantObjs: []runtime.Object{ testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). NumNodes(100). ContainerTrainer("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests). 
@@ -418,9 +418,12 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }, } cmpOpts := []cmp.Option{ - cmpopts.SortSlices(func(a, b client.Object) bool { + cmpopts.SortSlices(func(a, b runtime.Object) bool { return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() }), + cmpopts.SortSlices(func(a, b corev1.EnvVar) bool { + return a.Name < b.Name + }), cmpopts.EquateEmpty(), cmpopts.SortMaps(func(a, b string) bool { return a < b }), } @@ -432,16 +435,24 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { if tc.trainingRuntime != nil { clientBuilder.WithObjects(tc.trainingRuntime) } + c := clientBuilder.Build() - trainingRuntime, err := NewTrainingRuntime(ctx, clientBuilder.Build(), testingutil.AsIndex(clientBuilder)) + trainingRuntime, err := NewTrainingRuntime(ctx, c, testingutil.AsIndex(clientBuilder)) if err != nil { t.Fatal(err) } + objs, err := trainingRuntime.NewObjects(ctx, tc.trainJob) if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 { + + resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) 
+ if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { t.Errorf("Unexpected objects (-want,+got):\n%s", diff) } }) diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 3cc9d90440..b521dd5a0f 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -27,6 +27,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" + k8sruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -389,17 +390,17 @@ func TestRunComponentBuilderPlugins(t *testing.T) { } cases := map[string]struct { - registry fwkplugins.Registry - runtimeInfo *runtime.Info - trainJob *trainer.TrainJob - runtimeJobTemplate client.Object - wantRuntimeInfo *runtime.Info - wantObjs []client.Object - wantError error + registry fwkplugins.Registry + runtimeInfo *runtime.Info + trainingRuntime *trainer.TrainingRuntime + trainJob *trainer.TrainJob + wantRuntimeInfo *runtime.Info + wantObjs []k8sruntime.Object + wantError error }{ "succeeded to build PodGroup and JobSet with NumNodes from TrainJob": { - registry: fwkplugins.NewRegistry(), - runtimeJobTemplate: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").DeepCopy(), + registry: fwkplugins.NewRegistry(), + trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").DeepCopy(), runtimeInfo: &runtime.Info{ RuntimePolicy: runtime.RuntimePolicy{ MLPolicy: &trainer.MLPolicy{ @@ -431,6 +432,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }, trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime"). 
Trainer( testingutil.MakeTrainJobTrainerWrapper(). NumNodes(100). @@ -468,7 +470,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }, }, }, - wantObjs: []client.Object{ + wantObjs: []k8sruntime.Object{ testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). SchedulingTimeout(300). MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer. @@ -489,7 +491,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { "an empty registry": {}, } cmpOpts := []cmp.Option{ - cmpopts.SortSlices(func(a, b client.Object) bool { + cmpopts.SortSlices(func(a, b k8sruntime.Object) bool { return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() }), cmpopts.EquateEmpty(), @@ -498,9 +500,14 @@ func TestRunComponentBuilderPlugins(t *testing.T) { t.Run(name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + if tc.trainingRuntime != nil { + clientBuilder.WithObjects(tc.trainingRuntime) + } + c := clientBuilder.Build() - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) if err != nil { t.Fatal(err) } @@ -512,13 +519,21 @@ func TestRunComponentBuilderPlugins(t *testing.T) { t.Fatal(err) } objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected errors (-want,+got):\n%s", diff) } + if diff := cmp.Diff(tc.wantRuntimeInfo, tc.runtimeInfo); len(diff) != 0 { t.Errorf("Unexpected runtime.Info (-want,+got)\n%s", diff) } - if diff := cmp.Diff(tc.wantObjs, objs, cmpOpts...); len(diff) != 0 { + + resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) 
+ if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { t.Errorf("Unexpected objects (-want,+got):\n%s", diff) } }) @@ -654,14 +669,18 @@ func TestTerminalConditionPlugins(t *testing.T) { if tc.jobSet != nil { clientBuilder = clientBuilder.WithObjects(tc.jobSet) } - fwk, err := New(ctx, clientBuilder.Build(), tc.registry, testingutil.AsIndex(clientBuilder)) + c := clientBuilder.Build() + + fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) if err != nil { t.Fatal(err) } + gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob) if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } + if diff := cmp.Diff(tc.wantCondition, gotCond); len(diff) != 0 { t.Errorf("Unexpected terminal condition (-want,+got):\n%s", diff) } diff --git a/pkg/util/testing/unstructured.go b/pkg/util/testing/unstructured.go new file mode 100644 index 0000000000..cf94ca8c46 --- /dev/null +++ b/pkg/util/testing/unstructured.go @@ -0,0 +1,63 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package testing + +import ( + "encoding/json" + "fmt" + + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" +) + +func UnstructuredToObject(s *runtime.Scheme, objects ...*unstructured.Unstructured) ([]runtime.Object, error) { + var objs []runtime.Object + for _, obj := range objects { + if o, err := toObject(s, obj); err != nil { + return nil, err + } else { + objs = append(objs, o) + } + } + return objs, nil +} + +func toObject(s *runtime.Scheme, obj runtime.Object) (runtime.Object, error) { + u, isUnstructured := obj.(runtime.Unstructured) + if !isUnstructured { + return obj, nil + } + gvk := obj.GetObjectKind().GroupVersionKind() + if !s.Recognizes(gvk) { + return obj, nil + } + + typed, err := s.New(gvk) + if err != nil { + return nil, fmt.Errorf("scheme recognizes %s but failed to produce an object for it: %w", gvk, err) + } + + raw, err := json.Marshal(u) + if err != nil { + return nil, fmt.Errorf("failed to serialize %T: %w", raw, err) + } + if err := json.Unmarshal(raw, typed); err != nil { + return nil, fmt.Errorf("failed to unmarshal the content of %T into %T: %w", u, typed, err) + } + + return typed, nil +} diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index f2d735b12f..67eb9a566b 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -204,7 +204,6 @@ func (j *JobSetWrapper) InitContainerDatasetInitializerEnvFrom(envFrom []corev1. 
for k, container := range rJob.Template.Spec.Template.Spec.InitContainers { if container.Name == constants.ContainerDatasetInitializer { j.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[k].EnvFrom = envFrom - } } } @@ -218,7 +217,6 @@ func (j *JobSetWrapper) InitContainerModelInitializerEnv(env []corev1.EnvVar) *J for k, container := range rJob.Template.Spec.Template.Spec.InitContainers { if container.Name == constants.ContainerModelInitializer { j.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[k].Env = env - } } } @@ -232,7 +230,6 @@ func (j *JobSetWrapper) InitContainerModelInitializerEnvFrom(envFrom []corev1.En for k, container := range rJob.Template.Spec.Template.Spec.InitContainers { if container.Name == constants.ContainerModelInitializer { j.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[k].EnvFrom = envFrom - } } } diff --git a/pkg/webhooks/trainingruntime_webhook_test.go b/pkg/webhooks/trainingruntime_webhook_test.go index 9423e46d48..04710b5225 100644 --- a/pkg/webhooks/trainingruntime_webhook_test.go +++ b/pkg/webhooks/trainingruntime_webhook_test.go @@ -23,8 +23,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "k8s.io/apimachinery/pkg/util/validation/field" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" - - testingutil "github.com/kubeflow/trainer/pkg/util/testing" ) func TestValidateReplicatedJobs(t *testing.T) { @@ -33,14 +31,24 @@ func TestValidateReplicatedJobs(t *testing.T) { wantError field.ErrorList }{ "valid replicatedJobs": { - rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). - Replicas(1). - Obj().Spec.ReplicatedJobs, + rJobs: []jobsetv1alpha2.ReplicatedJob{ + { + Replicas: 1, + }, + { + Replicas: 1, + }, + }, }, "invalid replicas": { - rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). - Replicas(2). 
- Obj().Spec.ReplicatedJobs, + rJobs: []jobsetv1alpha2.ReplicatedJob{ + { + Replicas: 2, + }, + { + Replicas: 2, + }, + }, wantError: field.ErrorList{ field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(0).Child("replicas"), "2", ""), From a5a888352adce141db90e42de2183765e390b46b Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 14 Feb 2025 18:47:05 +0100 Subject: [PATCH 08/17] Fix JobSet resource requirements Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/jobset/builder.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index 706cedbdfe..515cc679dc 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -165,11 +165,17 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build if args := trainJob.Spec.Trainer.Args; args != nil { b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Args = args } - if resourcesPerNode := trainJob.Spec.Trainer.ResourcesPerNode; resourcesPerNode != nil { + if resourcesPerNode := trainJob.Spec.Trainer.ResourcesPerNode; resourcesPerNode != nil && + (resourcesPerNode.Limits != nil || resourcesPerNode.Requests != nil) { + requirements := corev1ac.ResourceRequirements() + if limits := resourcesPerNode.Limits; limits != nil { + requirements.WithLimits(limits) + } + if requests := resourcesPerNode.Requests; requests != nil { + requirements.WithRequests(requests) + } b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j]. - WithResources(corev1ac.ResourceRequirements(). - WithRequests(resourcesPerNode.Requests). - WithLimits(resourcesPerNode.Limits)) + WithResources(requirements) } } // Update values from the Info object. 
From d334084fda9bc6f377ca5a23eb0c32ba466443ba Mon Sep 17 00:00:00 2001
From: Antonin Stefanutti
Date: Fri, 14 Feb 2025 20:00:55 +0100
Subject: [PATCH 09/17] Resolve build issues with launcher job

Signed-off-by: Antonin Stefanutti
---
 .../framework/plugins/jobset/builder.go       | 64 ++++++++++---------
 pkg/util/apply/apply.go                       | 22 +++++++
 2 files changed, 56 insertions(+), 30 deletions(-)

diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go
index 515cc679dc..93613e01ce 100644
--- a/pkg/runtime/framework/plugins/jobset/builder.go
+++ b/pkg/runtime/framework/plugins/jobset/builder.go
@@ -18,7 +18,6 @@ package jobset

 import (
 	corev1 "k8s.io/api/core/v1"
-	"k8s.io/apimachinery/pkg/util/sets"
 	corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
 	"k8s.io/utils/ptr"
 	jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"
@@ -97,40 +96,27 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder {
 // Launcher updates JobSet values for the launcher Job.
 func (b *Builder) Launcher(info *runtime.Info, trainJob *trainer.TrainJob) *Builder {
 	for i, rJob := range b.Spec.ReplicatedJobs {
-		if rJob.Name == constants.JobLauncher {
+		if *rJob.Name == constants.JobLauncher {
-			// Update the volumes for the launcher Job.
-			b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes = append(
-				b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...)
+			// Update the volumes for the launcher Job.
+			upsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...)
 			// Update values for the launcher container.
 			for j, container := range rJob.Template.Spec.Template.Spec.Containers {
-				if container.Name == constants.ContainerLauncher {
+				if *container.Name == constants.ContainerLauncher {
 					// Update values from the Info object.
-					if info.Trainer.Env != nil {
+					if env := info.Trainer.Env; env != nil {
 						// Update JobSet envs from the Info.
- envNames := sets.New[string]() - for _, env := range info.Trainer.Env { - envNames.Insert(env.Name) - } - trainerEnvs := info.Trainer.Env - // Info envs take precedence over the TrainingRuntime envs. - for _, env := range container.Env { - if !envNames.Has(env.Name) { - trainerEnvs = append(trainerEnvs, env) - } - } - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env = trainerEnvs + upsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env...) } + // Update the launcher container port. - if info.Trainer.ContainerPort != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *info.Trainer.ContainerPort) + if port := info.Trainer.ContainerPort; port != nil { + upsertPorts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *port) } // Update the launcher container volume mounts. - if info.Trainer.VolumeMounts != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, info.Trainer.VolumeMounts...) + if mounts := info.Trainer.VolumeMounts; mounts != nil { + upsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts...) } } } @@ -148,8 +134,7 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build b.Spec.ReplicatedJobs[i].Template.Spec.Completions = info.Trainer.NumNodes // Update the volumes for the Trainer Job. - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...) + upsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...) // Update values for the Trainer container. 
for j, container := range rJob.Template.Spec.Template.Spec.Containers { @@ -188,9 +173,8 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build upsertPorts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *port) } // Update the Trainer container volume mounts. - if info.Trainer.VolumeMounts != nil { - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts = append( - b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, info.Trainer.VolumeMounts...) + if mounts := info.Trainer.VolumeMounts; mounts != nil { + upsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts...) } } } @@ -231,6 +215,18 @@ func upsertPorts(portList *[]corev1ac.ContainerPortApplyConfiguration, ports ... } } +func upsertVolumes(volumeList *[]corev1ac.VolumeApplyConfiguration, volumes ...corev1.Volume) { + for _, v := range volumes { + upsert(volumeList, apply.Volume(v), byVolumeName) + } +} + +func upsertVolumeMounts(mountList *[]corev1ac.VolumeMountApplyConfiguration, mounts ...corev1.VolumeMount) { + for _, m := range mounts { + upsert(mountList, apply.VolumeMount(m), byVolumeMountName) + } +} + func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { return ptr.Equal(a.Name, b.Name) } @@ -239,6 +235,14 @@ func byContainerPortOrName(a, b corev1ac.ContainerPortApplyConfiguration) bool { return ptr.Equal(a.ContainerPort, b.ContainerPort) || ptr.Equal(a.Name, b.Name) } +func byVolumeName(a, b corev1ac.VolumeApplyConfiguration) bool { + return ptr.Equal(a.Name, b.Name) +} + +func byVolumeMountName(a, b corev1ac.VolumeMountApplyConfiguration) bool { + return ptr.Equal(a.Name, b.Name) +} + type compare[T any] func(T, T) bool func upsert[T any](items *[]T, item *T, predicate compare[T]) { diff --git a/pkg/util/apply/apply.go b/pkg/util/apply/apply.go index e8ccd13c79..4ce49d90d3 100644 --- a/pkg/util/apply/apply.go +++ 
b/pkg/util/apply/apply.go @@ -151,3 +151,25 @@ func Conditions(c ...metav1.Condition) []*metav1ac.ConditionApplyConfiguration { } return conditions } + +func Volume(v corev1.Volume) *corev1ac.VolumeApplyConfiguration { + volume := corev1ac.Volume().WithName(v.Name) + // FIXME + return volume +} + +func VolumeMount(m corev1.VolumeMount) *corev1ac.VolumeMountApplyConfiguration { + volumeMount := corev1ac.VolumeMount().WithName(m.Name) + if m.MountPath != "" { + volumeMount.WithMountPath(m.MountPath) + } + return volumeMount +} + +func VolumeMounts(v ...corev1.VolumeMount) []*corev1ac.VolumeMountApplyConfiguration { + var mounts []*corev1ac.VolumeMountApplyConfiguration + for _, mount := range v { + mounts = append(mounts, VolumeMount(mount)) + } + return mounts +} From 2a5c276fba0bafa4d926b3c53297157257c4044e Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 17 Feb 2025 09:16:03 +0100 Subject: [PATCH 10/17] Use apply config for MPI ConfigMap and Secret Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/mpi/mpi.go | 99 ++++++++---------------- 1 file changed, 33 insertions(+), 66 deletions(-) diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index 03dcbd0c1a..b357cf95b8 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -25,22 +25,19 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "maps" "strconv" "golang.org/x/crypto/ssh" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/equality" - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" 
"sigs.k8s.io/controller-runtime/pkg/client" - ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -149,7 +146,7 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 // TODO (andreyvelich): Support other MPI implementations. - infoEnvs := []corev1.EnvVar{} + var infoEnvs []corev1.EnvVar switch info.RuntimePolicy.MLPolicy.MPI.MPIImplementation { case trainer.MPIImplementationOpenMPI: infoEnvs = append(infoEnvs, []corev1.EnvVar{ @@ -209,17 +206,17 @@ func (m *MPI) ReconcilerBuilders() []runtime.ReconcilerBuilder { } } -func (m *MPI) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { +func (m *MPI) Build(_ context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.MPI == nil { return nil, nil } - secret, err := m.buildSSHAuthSecret(ctx, trainJob) + secret, err := m.buildSSHAuthSecret(trainJob) if err != nil { return nil, fmt.Errorf("failed to build Secret with SSH auth keys. Error: %v", err) } - configMap, err := m.buildHostFileConfigMap(ctx, info, trainJob) + configMap, err := m.buildHostFileConfigMap(info, trainJob) if err != nil { return nil, fmt.Errorf("failed to build ConfigMap with hostfile. 
Error: %v", err) } @@ -227,7 +224,7 @@ func (m *MPI) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.T return []any{secret, configMap}, nil } -func (m *MPI) buildSSHAuthSecret(ctx context.Context, trainJob *trainer.TrainJob) (*corev1.Secret, error) { +func (m *MPI) buildSSHAuthSecret(trainJob *trainer.TrainJob) (*corev1ac.SecretApplyConfiguration, error) { // Generate SSH private and public keys. privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { @@ -250,40 +247,25 @@ func (m *MPI) buildSSHAuthSecret(ctx context.Context, trainJob *trainer.TrainJob } // Create Secret to store ssh keys. - secret := &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: trainJob.Name + constants.MPISSHAuthSecretSuffix, - Namespace: trainJob.Namespace, - }, - Type: corev1.SecretTypeSSHAuth, - Data: map[string][]byte{ + secret := corev1ac.Secret(trainJob.Name+constants.MPISSHAuthSecretSuffix, trainJob.Namespace). + WithType(corev1.SecretTypeSSHAuth). + WithData(map[string][]byte{ corev1.SSHAuthPrivateKey: privatePEM, constants.MPISSHPublicKey: ssh.MarshalAuthorizedKey(publicKey), - }, - } - if err := ctrlutil.SetControllerReference(trainJob, secret, m.scheme); err != nil { - return nil, err - } - oldSecret := &corev1.Secret{} - if err := m.client.Get(ctx, client.ObjectKeyFromObject(secret), oldSecret); err != nil { - if !apierrors.IsNotFound(err) { - return nil, err - } - oldSecret = nil - } - if needsCreateOrUpdateSecret(oldSecret, secret, ptr.Deref(trainJob.Spec.Suspend, false)) { - return secret, nil - } - return nil, nil -} + }) + + secret.WithOwnerReferences(metav1ac.OwnerReference(). + WithAPIVersion(trainer.GroupVersion.String()). + WithKind(trainer.TrainJobKind). + WithName(trainJob.Name). + WithUID(trainJob.UID). + WithController(true). 
+ WithBlockOwnerDeletion(true)) -func needsCreateOrUpdateSecret(old, new *corev1.Secret, trainJobIsSuspended bool) bool { - return old == nil || - trainJobIsSuspended && - (!equality.Semantic.DeepEqual(old.Data, new.Data) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) + return secret, nil } -func (m *MPI) buildHostFileConfigMap(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) (*corev1.ConfigMap, error) { +func (m *MPI) buildHostFileConfigMap(info *runtime.Info, trainJob *trainer.TrainJob) (*corev1ac.ConfigMapApplyConfiguration, error) { // Generate hostfile for the MPI communication. var hostfile bytes.Buffer // TODO (andreyvelich): Support other MPI implementations. @@ -296,33 +278,18 @@ func (m *MPI) buildHostFileConfigMap(ctx context.Context, info *runtime.Info, tr } // Create ConfigMap to store hostfile. - configMap := &corev1.ConfigMap{ - ObjectMeta: metav1.ObjectMeta{ - Name: trainJob.Name + constants.MPIHostfileConfigMapSuffix, - Namespace: trainJob.Namespace, - }, - Data: map[string]string{ + configMap := corev1ac.ConfigMap(trainJob.Name+constants.MPIHostfileConfigMapSuffix, trainJob.Namespace). + WithData(map[string]string{ constants.MPIHostfileName: hostfile.String(), - }, - } - if err := ctrlutil.SetControllerReference(trainJob, configMap, m.scheme); err != nil { - return nil, err - } - oldConfigMap := &corev1.ConfigMap{} - if err := m.client.Get(ctx, client.ObjectKeyFromObject(configMap), oldConfigMap); err != nil { - if !apierrors.IsNotFound(err) { - return nil, err - } - oldConfigMap = nil - } - if needsCreateOrUpdateConfigMap(oldConfigMap, configMap, ptr.Deref(trainJob.Spec.Suspend, false)) { - return configMap, nil - } - return nil, nil -} + }) + + configMap.WithOwnerReferences(metav1ac.OwnerReference(). + WithAPIVersion(trainer.GroupVersion.String()). + WithKind(trainer.TrainJobKind). + WithName(trainJob.Name). + WithUID(trainJob.UID). + WithController(true). 
+ WithBlockOwnerDeletion(true)) -func needsCreateOrUpdateConfigMap(old, new *corev1.ConfigMap, trainJobIsSuspended bool) bool { - return old == nil || - trainJobIsSuspended && - (!equality.Semantic.DeepEqual(old.Data, new.Data) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)) + return configMap, nil } From 86c0196fa252f7183b7615d2d361ff0421cc2a5d Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 17 Feb 2025 09:17:04 +0100 Subject: [PATCH 11/17] ComponentBuilderPlugin now returns an array Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/core/framework.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 8c334729ca..c9d1db6da2 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -117,13 +117,15 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { var objs []*unstructured.Unstructured for _, plugin := range f.componentBuilderPlugins { - if component, err := plugin.Build(ctx, info, trainJob); err != nil { + if components, err := plugin.Build(ctx, info, trainJob); err != nil { return nil, err - } else if component != nil { - if content, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(component); err != nil { - return nil, err - } else { - objs = append(objs, &unstructured.Unstructured{Object: content}) + } else if components != nil { + for _, component := range components { + if content, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(component); err != nil { + return nil, err + } else { + objs = append(objs, &unstructured.Unstructured{Object: content}) + } } } } From 099e9e4fe7e8e518afd13fbb9a06e707c3acd44e Mon Sep 17 00:00:00 2001 
From: Antonin Stefanutti Date: Mon, 17 Feb 2025 11:09:28 +0100 Subject: [PATCH 12/17] Use plain apply configurations instead of unstructured Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 29 ++++++++++++++----- pkg/runtime/core/clustertrainingruntime.go | 3 +- .../core/clustertrainingruntime_test.go | 2 +- pkg/runtime/core/trainingruntime.go | 5 ++-- pkg/runtime/core/trainingruntime_test.go | 2 +- pkg/runtime/framework/core/framework.go | 14 ++------- pkg/runtime/framework/core/framework_test.go | 2 +- pkg/runtime/interface.go | 7 +++-- .../testing/{unstructured.go => runtime.go} | 22 +++++++------- 9 files changed, 48 insertions(+), 38 deletions(-) rename pkg/util/testing/{unstructured.go => runtime.go} (72%) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index cb3438b19f..6ea57a5a18 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -26,6 +26,8 @@ import ( "k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + k8sruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" @@ -105,26 +107,39 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) (objsOpState, error) { log := ctrl.LoggerFrom(ctx) - objs, err := runtime.NewObjects(ctx, trainJob) + objects, err := runtime.NewObjects(ctx, trainJob) if err != nil { return buildFailed, err } - for _, obj := range objs { + for _, object := range objects { + // TODO (astefanutti): Remove conversion to unstructured when the runtime.ApplyConfiguration + // interface becomes available and first-class SSA method is added to the controller-runtime + // client. 
See https://github.com/kubernetes/kubernetes/pull/129313 + var obj client.Object + if o, ok := object.(client.Object); ok { + obj = o + } else { + if u, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(object); err != nil { + return buildFailed, err + } else { + obj = &unstructured.Unstructured{Object: u} + } + } + + if err := r.client.Patch(ctx, obj, client.Apply, client.FieldOwner("trainer"), client.ForceOwnership); err != nil { + return buildFailed, err + } + var gvk schema.GroupVersionKind if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil { return buildFailed, err } - logKeysAndValues := []any{ "groupVersionKind", gvk.String(), "namespace", obj.GetNamespace(), "name", obj.GetName(), } - if err := r.client.Patch(ctx, obj, client.Apply, client.FieldOwner("trainer"), client.ForceOwnership); err != nil { - return buildFailed, err - } - log.V(5).Info("Succeeded to update object", logKeysAndValues...) } return creationSucceeded, nil diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 8bea04c382..cf5df52fef 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -22,7 +22,6 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -53,7 +52,7 @@ func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndex }, nil } -func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { +func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) { var clTrainingRuntime trainer.ClusterTrainingRuntime if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name}, 
&clTrainingRuntime); err != nil { return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err) diff --git a/pkg/runtime/core/clustertrainingruntime_test.go b/pkg/runtime/core/clustertrainingruntime_test.go index 658ae20f66..f63b9443db 100644 --- a/pkg/runtime/core/clustertrainingruntime_test.go +++ b/pkg/runtime/core/clustertrainingruntime_test.go @@ -131,7 +131,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) + resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index 157e742455..9cf182fec0 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -22,7 +22,6 @@ import ( "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -71,7 +70,7 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie return trainingRuntimeFactory, nil } -func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { +func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) { var trainingRuntime trainer.TrainingRuntime err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime) if err != nil { @@ -82,7 +81,7 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai func (r *TrainingRuntime) buildObjects( ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy 
*trainer.PodGroupPolicy, -) ([]*unstructured.Unstructured, error) { +) ([]any, error) { propagationLabels := jobSetTemplateSpec.Labels if propagationLabels == nil && trainJob.Spec.Labels != nil { propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index 7fda722d9a..afda250fdf 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -447,7 +447,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) + resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index c9d1db6da2..476c2d333a 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -21,8 +21,6 @@ import ( "errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - k8sruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -114,19 +112,13 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) return aggregatedWarnings, aggregatedErrors } -func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) { - var objs []*unstructured.Unstructured +func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { + var objs []any for _, plugin := range f.componentBuilderPlugins { if components, err := plugin.Build(ctx, info, trainJob); err != nil { return nil, err } else if components != 
nil { - for _, component := range components { - if content, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(component); err != nil { - return nil, err - } else { - objs = append(objs, &unstructured.Unstructured{Object: content}) - } - } + objs = append(objs, components...) } } return objs, nil diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index b521dd5a0f..d0f7c4812c 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -528,7 +528,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { t.Errorf("Unexpected runtime.Info (-want,+got)\n%s", diff) } - resultObjs, err := testingutil.UnstructuredToObject(c.Scheme(), objs...) + resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) if err != nil { t.Fatal(err) } diff --git a/pkg/runtime/interface.go b/pkg/runtime/interface.go index 3b045ee0a0..8b1fba7af9 100644 --- a/pkg/runtime/interface.go +++ b/pkg/runtime/interface.go @@ -20,7 +20,6 @@ import ( "context" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -33,7 +32,11 @@ import ( type ReconcilerBuilder func(*builder.Builder, client.Client, cache.Cache) *builder.Builder type Runtime interface { - NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) + + // TODO (astefanutti): Change the return type from []any to []runtime.ApplyConfiguration when + // https://github.com/kubernetes/kubernetes/pull/129313 becomes available + + NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) 
(admission.Warnings, field.ErrorList) diff --git a/pkg/util/testing/unstructured.go b/pkg/util/testing/runtime.go similarity index 72% rename from pkg/util/testing/unstructured.go rename to pkg/util/testing/runtime.go index cf94ca8c46..1cf4abd7d0 100644 --- a/pkg/util/testing/unstructured.go +++ b/pkg/util/testing/runtime.go @@ -24,7 +24,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) -func UnstructuredToObject(s *runtime.Scheme, objects ...*unstructured.Unstructured) ([]runtime.Object, error) { +func ToObject(s *runtime.Scheme, objects ...any) ([]runtime.Object, error) { var objs []runtime.Object for _, obj := range objects { if o, err := toObject(s, obj); err != nil { @@ -36,21 +36,24 @@ func UnstructuredToObject(s *runtime.Scheme, objects ...*unstructured.Unstructur return objs, nil } -func toObject(s *runtime.Scheme, obj runtime.Object) (runtime.Object, error) { - u, isUnstructured := obj.(runtime.Unstructured) - if !isUnstructured { - return obj, nil +func toObject(s *runtime.Scheme, obj any) (runtime.Object, error) { + if o, ok := obj.(runtime.Object); ok { + return o, nil } - gvk := obj.GetObjectKind().GroupVersionKind() + var u *unstructured.Unstructured + if o, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj); err != nil { + return nil, err + } else { + u = &unstructured.Unstructured{Object: o} + } + gvk := u.GetObjectKind().GroupVersionKind() if !s.Recognizes(gvk) { - return obj, nil + return nil, fmt.Errorf("%s is not a recognized schema", gvk.GroupVersion().String()) } - typed, err := s.New(gvk) if err != nil { return nil, fmt.Errorf("scheme recognizes %s but failed to produce an object for it: %w", gvk, err) } - raw, err := json.Marshal(u) if err != nil { return nil, fmt.Errorf("failed to serialize %T: %w", raw, err) @@ -58,6 +61,5 @@ func toObject(s *runtime.Scheme, obj runtime.Object) (runtime.Object, error) { if err := json.Unmarshal(raw, typed); err != nil { return nil, fmt.Errorf("failed to unmarshal the content of %T into %T: 
%w", u, typed, err) } - return typed, nil } From d38ba49bf301ed08bff232bc8dd388200ab9ed06 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 17 Feb 2025 14:09:50 +0100 Subject: [PATCH 13/17] Use apply config in EnforceMLPolicy plugins Signed-off-by: Antonin Stefanutti --- .../framework/plugins/jobset/builder.go | 88 ++-------- pkg/runtime/framework/plugins/mpi/mpi.go | 108 +++++------- .../framework/plugins/plainml/plainml.go | 3 +- pkg/runtime/framework/plugins/torch/torch.go | 71 +++----- pkg/runtime/runtime.go | 9 +- pkg/util/apply/apply.go | 156 ++++++------------ 6 files changed, 139 insertions(+), 296 deletions(-) diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index 93613e01ce..c59326e35e 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -17,9 +17,7 @@ limitations under the License. package jobset import ( - corev1 "k8s.io/api/core/v1" corev1ac "k8s.io/client-go/applyconfigurations/core/v1" - "k8s.io/utils/ptr" jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -51,12 +49,11 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder { env := &b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env // Update the dataset initializer envs. if storageUri := trainJob.Spec.DatasetConfig.StorageUri; storageUri != nil { - upsertEnvVars(env, corev1.EnvVar{ - Name: InitializerEnvStorageUri, - Value: *storageUri, - }) + apply.UpsertEnvVar(env, corev1ac.EnvVar(). + WithName(InitializerEnvStorageUri). + WithValue(*storageUri)) } - upsertEnvVars(env, trainJob.Spec.DatasetConfig.Env...) + apply.UpsertEnvVars(env, apply.EnvVars(trainJob.Spec.DatasetConfig.Env...)) // Update the dataset initializer secret reference. 
if trainJob.Spec.DatasetConfig.SecretRef != nil { b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j]. @@ -73,12 +70,11 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder { // Update the model initializer envs. env := &b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j].Env if storageUri := trainJob.Spec.ModelConfig.Input.StorageUri; storageUri != nil { - upsertEnvVars(env, corev1.EnvVar{ - Name: InitializerEnvStorageUri, - Value: *storageUri, - }) + apply.UpsertEnvVar(env, corev1ac.EnvVar(). + WithName(InitializerEnvStorageUri). + WithValue(*storageUri)) } - upsertEnvVars(env, trainJob.Spec.ModelConfig.Input.Env...) + apply.UpsertEnvVars(env, apply.EnvVars(trainJob.Spec.ModelConfig.Input.Env...)) // Update the model initializer secret reference. if trainJob.Spec.ModelConfig.Input.SecretRef != nil { b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.InitContainers[j]. @@ -99,7 +95,7 @@ func (b *Builder) Launcher(info *runtime.Info, trainJob *trainer.TrainJob) *Buil if *rJob.Name == constants.JobLauncher { // Update the volumes for the Trainer Job. - upsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...) + apply.UpsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes) // Update values for the launcher container. for j, container := range rJob.Template.Spec.Template.Spec.Containers { @@ -107,16 +103,16 @@ func (b *Builder) Launcher(info *runtime.Info, trainJob *trainer.TrainJob) *Buil // Update values from the Info object. if env := info.Trainer.Env; env != nil { // Update JobSet envs from the Info. - upsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env...) + apply.UpsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env) } // Update the launcher container port. 
if port := info.Trainer.ContainerPort; port != nil { - upsertPorts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *port) + apply.UpsertPort(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, port) } // Update the launcher container volume mounts. if mounts := info.Trainer.VolumeMounts; mounts != nil { - upsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts...) + apply.UpsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts) } } } @@ -134,7 +130,7 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build b.Spec.ReplicatedJobs[i].Template.Spec.Completions = info.Trainer.NumNodes // Update the volumes for the Trainer Job. - upsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes...) + apply.UpsertVolumes(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Volumes, info.Trainer.Volumes) // Update values for the Trainer container. for j, container := range rJob.Template.Spec.Template.Spec.Containers { @@ -166,15 +162,15 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build // Update values from the Info object. if env := info.Trainer.Env; env != nil { // Update JobSet envs from the Info. - upsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env...) + apply.UpsertEnvVars(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env, env) } // Update the Trainer container port. if port := info.Trainer.ContainerPort; port != nil { - upsertPorts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, *port) + apply.UpsertPort(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Ports, port) } // Update the Trainer container volume mounts. 
if mounts := info.Trainer.VolumeMounts; mounts != nil { - upsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts...) + apply.UpsertVolumeMounts(&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].VolumeMounts, mounts) } } } @@ -202,55 +198,3 @@ func (b *Builder) Suspend(suspend *bool) *Builder { func (b *Builder) Build() *jobsetv1alpha2ac.JobSetApplyConfiguration { return b.JobSetApplyConfiguration } - -func upsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars ...corev1.EnvVar) { - for _, e := range envVars { - upsert(envVarList, apply.EnvVar(e), byEnvVarName) - } -} - -func upsertPorts(portList *[]corev1ac.ContainerPortApplyConfiguration, ports ...corev1.ContainerPort) { - for _, p := range ports { - upsert(portList, apply.ContainerPort(p), byContainerPortOrName) - } -} - -func upsertVolumes(volumeList *[]corev1ac.VolumeApplyConfiguration, volumes ...corev1.Volume) { - for _, v := range volumes { - upsert(volumeList, apply.Volume(v), byVolumeName) - } -} - -func upsertVolumeMounts(mountList *[]corev1ac.VolumeMountApplyConfiguration, mounts ...corev1.VolumeMount) { - for _, m := range mounts { - upsert(mountList, apply.VolumeMount(m), byVolumeMountName) - } -} - -func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { - return ptr.Equal(a.Name, b.Name) -} - -func byContainerPortOrName(a, b corev1ac.ContainerPortApplyConfiguration) bool { - return ptr.Equal(a.ContainerPort, b.ContainerPort) || ptr.Equal(a.Name, b.Name) -} - -func byVolumeName(a, b corev1ac.VolumeApplyConfiguration) bool { - return ptr.Equal(a.Name, b.Name) -} - -func byVolumeMountName(a, b corev1ac.VolumeMountApplyConfiguration) bool { - return ptr.Equal(a.Name, b.Name) -} - -type compare[T any] func(T, T) bool - -func upsert[T any](items *[]T, item *T, predicate compare[T]) { - for i, t := range *items { - if predicate(t, *item) { - (*items)[i] = *item - return - } - } - *items = append(*items, *item) -} 
diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index b357cf95b8..ec81da74db 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -30,7 +30,6 @@ import ( "golang.org/x/crypto/ssh" corev1 "k8s.io/api/core/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" corev1ac "k8s.io/client-go/applyconfigurations/core/v1" metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" @@ -44,6 +43,7 @@ import ( "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" + "github.com/kubeflow/trainer/pkg/util/apply" ) type MPI struct { @@ -96,49 +96,35 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er info.Trainer.NumProcPerNode = numProcPerNode // Add Secret and ConfigMap volumes to the Info object - info.Volumes = []corev1.Volume{ - { - Name: constants.MPISSHAuthVolumeName, - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: trainJob.Name + constants.MPISSHAuthSecretSuffix, - Items: []corev1.KeyToPath{ - { - Key: corev1.SSHAuthPrivateKey, - Path: constants.MPISSHPrivateKeyFile, - }, - { - Key: constants.MPISSHPublicKey, - Path: constants.MPISSHPublicKeyFile, - }, - { - Key: constants.MPISSHPublicKey, - Path: constants.MPISSHAuthorizedKeys, - }, - }, - }, - }, - }, - { - Name: constants.MPIHostfileVolumeName, - VolumeSource: corev1.VolumeSource{ - ConfigMap: &corev1.ConfigMapVolumeSource{ - LocalObjectReference: corev1.LocalObjectReference{ - Name: trainJob.Name + constants.MPIHostfileConfigMapSuffix, - }, - }, - }, - }, + info.Volumes = []corev1ac.VolumeApplyConfiguration{ + *corev1ac.Volume(). + WithName(constants.MPISSHAuthVolumeName). + WithSecret(corev1ac.SecretVolumeSource(). + WithSecretName(trainJob.Name+constants.MPISSHAuthSecretSuffix). 
+ WithItems( + corev1ac.KeyToPath(). + WithKey(corev1.SSHAuthPrivateKey). + WithPath(constants.MPISSHPrivateKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHPublicKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHAuthorizedKeys), + )), + *corev1ac.Volume(). + WithName(constants.MPIHostfileVolumeName). + WithConfigMap(corev1ac.ConfigMapVolumeSource(). + WithName(trainJob.Name + constants.MPIHostfileConfigMapSuffix)), } - info.VolumeMounts = []corev1.VolumeMount{ - { - Name: constants.MPISSHAuthVolumeName, - MountPath: info.RuntimePolicy.MLPolicy.MPI.SSHAuthMountPath, - }, - { - Name: constants.MPIHostfileVolumeName, - MountPath: constants.MPIHostfileDir, - }, + + info.VolumeMounts = []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(constants.MPISSHAuthVolumeName). + WithMountPath(info.RuntimePolicy.MLPolicy.MPI.SSHAuthMountPath), + *corev1ac.VolumeMount(). + WithName(constants.MPIHostfileVolumeName). + WithMountPath(constants.MPIHostfileDir), } // Update envs for Info object. @@ -146,39 +132,23 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 // TODO (andreyvelich): Support other MPI implementations. - var infoEnvs []corev1.EnvVar + + if trainJob.Spec.Trainer != nil { + info.Trainer.Env = apply.EnvVars(trainJob.Spec.Trainer.Env...) + } + switch info.RuntimePolicy.MLPolicy.MPI.MPIImplementation { case trainer.MPIImplementationOpenMPI: - infoEnvs = append(infoEnvs, []corev1.EnvVar{ - { - Name: constants.OpenMPIEnvHostFileLocation, - Value: fmt.Sprintf("%s/%s", constants.MPIHostfileDir, constants.MPIHostfileName), - }}...) + apply.UpsertEnvVar(&info.Trainer.Env, corev1ac.EnvVar(). 
+ WithName(constants.OpenMPIEnvHostFileLocation). + WithValue(fmt.Sprintf("%s/%s", constants.MPIHostfileDir, constants.MPIHostfileName))) default: return fmt.Errorf("MPI implementation for %s doesn't supported", info.RuntimePolicy.MLPolicy.MPI.MPIImplementation) } - // Set for all Info envs. - envNames := sets.New[string]() - for _, env := range infoEnvs { - envNames.Insert(env.Name) - } - // Info envs take precedence over TrainJob envs. - if trainJob.Spec.Trainer != nil { - for _, env := range trainJob.Spec.Trainer.Env { - if !envNames.Has(env.Name) { - info.Trainer.Env = append(info.Trainer.Env, corev1.EnvVar{Name: env.Name, Value: env.Value}) - } - } - } - - // Insert MPI distributed envs into the list end. - info.Trainer.Env = append(info.Trainer.Env, infoEnvs...) - // Add container port for the headless service. - info.Trainer.ContainerPort = &corev1.ContainerPort{ - ContainerPort: constants.ContainerTrainerPort, - } + info.Trainer.ContainerPort = corev1ac.ContainerPort(). + WithContainerPort(constants.ContainerTrainerPort) // Update total Pod requests for the PodGroupPolicy plugin. for rName := range info.TotalRequests { diff --git a/pkg/runtime/framework/plugins/plainml/plainml.go b/pkg/runtime/framework/plugins/plainml/plainml.go index 74037124fa..b389b97d58 100644 --- a/pkg/runtime/framework/plugins/plainml/plainml.go +++ b/pkg/runtime/framework/plugins/plainml/plainml.go @@ -26,6 +26,7 @@ import ( "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" + "github.com/kubeflow/trainer/pkg/util/apply" ) var _ framework.EnforceMLPolicyPlugin = (*PlainML)(nil) @@ -57,7 +58,7 @@ func (p *PlainML) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob // Add envs from the TrainJob. if trainJob.Spec.Trainer != nil { - info.Trainer.Env = append(info.Trainer.Env, trainJob.Spec.Trainer.Env...) 
+ apply.UpsertEnvVars(&info.Trainer.Env, apply.EnvVars(trainJob.Spec.Trainer.Env...)) } // Update total Pod requests for the PodGroupPolicy plugin. diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index af9a04c456..bf715dd215 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -20,10 +20,9 @@ import ( "context" "fmt" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/intstr" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -32,6 +31,7 @@ import ( "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" + "github.com/kubeflow/trainer/pkg/util/apply" ) type Torch struct{} @@ -77,54 +77,33 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) // TODO (andreyvelich): Add validation to check that TrainJob doesn't have "PET_" envs. // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. 
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 - - infoEnvs := []corev1.EnvVar{ - { - Name: constants.TorchEnvNumNodes, - Value: fmt.Sprintf("%d", ptr.Deref(numNodes, 1)), - }, - { - Name: constants.TorchEnvNumProcPerNode, - Value: numProcPerNode.String(), - }, - { - Name: constants.TorchEnvNodeRank, - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - FieldPath: constants.JobCompletionIndexFieldPath, - }, - }, - }, - { - Name: constants.TorchEnvMasterAddr, - Value: fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.JobTrainerNode, trainJob.Name), - }, - { - Name: constants.TorchEnvMasterPort, - Value: fmt.Sprintf("%d", constants.ContainerTrainerPort), - }, - } - - // Set for all Info envs. - envNames := sets.New[string]() - for _, env := range infoEnvs { - envNames.Insert(env.Name) - } - // Info envs take precedence over TrainJob envs. if trainJob.Spec.Trainer != nil { - for _, env := range trainJob.Spec.Trainer.Env { - if !envNames.Has(env.Name) { - info.Trainer.Env = append(info.Trainer.Env, corev1.EnvVar{Name: env.Name, Value: env.Value}) - } - } + info.Trainer.Env = apply.EnvVars(trainJob.Spec.Trainer.Env...) } - // Insert Torch distributed envs into the list end. - info.Trainer.Env = append(info.Trainer.Env, infoEnvs...) + + apply.UpsertEnvVar(&info.Trainer.Env, + corev1ac.EnvVar(). + WithName(constants.TorchEnvNumNodes). + WithValue(fmt.Sprintf("%d", ptr.Deref(numNodes, 1))), + corev1ac.EnvVar(). + WithName(constants.TorchEnvNumProcPerNode). + WithValue(numProcPerNode.String()), + corev1ac.EnvVar(). + WithName(constants.TorchEnvNodeRank). + WithValueFrom(corev1ac.EnvVarSource(). + WithFieldRef(corev1ac.ObjectFieldSelector(). + WithFieldPath(constants.JobCompletionIndexFieldPath))), + corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterAddr). + WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.JobTrainerNode, trainJob.Name)), + corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterPort). 
+ WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)), + ) // Add container port for the headless service. - info.Trainer.ContainerPort = &corev1.ContainerPort{ - ContainerPort: constants.ContainerTrainerPort, - } + info.Trainer.ContainerPort = corev1ac.ContainerPort(). + WithContainerPort(constants.ContainerTrainerPort) // Update total Pod requests for the PodGroupPolicy plugin. for rName := range info.TotalRequests { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index d69146fd99..4352a7cb49 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -20,6 +20,7 @@ import ( "maps" corev1 "k8s.io/api/core/v1" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" kueuelr "sigs.k8s.io/kueue/pkg/util/limitrange" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -47,10 +48,10 @@ type Trainer struct { NumProcPerNode string // TODO (andreyvelich). Potentially, we can use map for env and sort it to improve code. // Context: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823267183 - Env []corev1.EnvVar - ContainerPort *corev1.ContainerPort - Volumes []corev1.Volume - VolumeMounts []corev1.VolumeMount + Env []corev1ac.EnvVarApplyConfiguration + ContainerPort *corev1ac.ContainerPortApplyConfiguration + Volumes []corev1ac.VolumeApplyConfiguration + VolumeMounts []corev1ac.VolumeMountApplyConfiguration } // TODO (andreyvelich): Potentially, we can add ScheduleTimeoutSeconds to the Scheduler for consistency. 
diff --git a/pkg/util/apply/apply.go b/pkg/util/apply/apply.go index 4ce49d90d3..83a52a7eea 100644 --- a/pkg/util/apply/apply.go +++ b/pkg/util/apply/apply.go @@ -18,37 +18,66 @@ package apply import ( corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" corev1ac "k8s.io/client-go/applyconfigurations/core/v1" - metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" + "k8s.io/utils/ptr" ) -func ContainerPort(p corev1.ContainerPort) *corev1ac.ContainerPortApplyConfiguration { - port := corev1ac.ContainerPort() - if p.ContainerPort > 0 { - port.WithContainerPort(p.ContainerPort) +func UpsertEnvVar(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVar ...*corev1ac.EnvVarApplyConfiguration) { + for _, e := range envVar { + upsert(envVarList, *e, byEnvVarName) } - if p.HostPort > 0 { - port.WithHostPort(p.HostPort) +} + +func UpsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars []corev1ac.EnvVarApplyConfiguration) { + for _, e := range envVars { + upsert(envVarList, e, byEnvVarName) } - if p.HostIP != "" { - port.WithHostIP(p.HostIP) +} + +func UpsertPort(portList *[]corev1ac.ContainerPortApplyConfiguration, port ...*corev1ac.ContainerPortApplyConfiguration) { + for _, p := range port { + upsert(portList, *p, byContainerPortOrName) } - if p.Name != "" { - port.WithName(p.Name) +} + +func UpsertVolumes(volumeList *[]corev1ac.VolumeApplyConfiguration, volumes []corev1ac.VolumeApplyConfiguration) { + for _, v := range volumes { + upsert(volumeList, v, byVolumeName) } - if p.Protocol != "" { - port.WithProtocol(p.Protocol) +} + +func UpsertVolumeMounts(mountList *[]corev1ac.VolumeMountApplyConfiguration, mounts []corev1ac.VolumeMountApplyConfiguration) { + for _, m := range mounts { + upsert(mountList, m, byVolumeMountName) } - return port } -func ContainerPorts(p ...corev1.ContainerPort) []*corev1ac.ContainerPortApplyConfiguration { - var ports []*corev1ac.ContainerPortApplyConfiguration - for _, port := range p { - ports = 
append(ports, ContainerPort(port)) +func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { + return ptr.Equal(a.Name, b.Name) +} + +func byContainerPortOrName(a, b corev1ac.ContainerPortApplyConfiguration) bool { + return ptr.Equal(a.ContainerPort, b.ContainerPort) || ptr.Equal(a.Name, b.Name) +} + +func byVolumeName(a, b corev1ac.VolumeApplyConfiguration) bool { + return ptr.Equal(a.Name, b.Name) +} + +func byVolumeMountName(a, b corev1ac.VolumeMountApplyConfiguration) bool { + return ptr.Equal(a.Name, b.Name) +} + +type compare[T any] func(T, T) bool + +func upsert[T any](items *[]T, item T, predicate compare[T]) { + for i, t := range *items { + if predicate(t, item) { + (*items)[i] = item + return + } } - return ports + *items = append(*items, item) } func EnvVar(e corev1.EnvVar) *corev1ac.EnvVarApplyConfiguration { @@ -85,91 +114,10 @@ func EnvVar(e corev1.EnvVar) *corev1ac.EnvVarApplyConfiguration { return envVar } -func EnvVars(e ...corev1.EnvVar) []*corev1ac.EnvVarApplyConfiguration { - var envs []*corev1ac.EnvVarApplyConfiguration - for _, env := range e { - envs = append(envs, EnvVar(env)) - } - return envs -} - -func EnvFromSource(e corev1.EnvFromSource) *corev1ac.EnvFromSourceApplyConfiguration { - envVarFrom := corev1ac.EnvFromSource() - if e.Prefix != "" { - envVarFrom.WithPrefix(e.Prefix) - } - if ref := e.ConfigMapRef; ref != nil { - source := corev1ac.ConfigMapEnvSource().WithName(ref.Name) - if ref.Optional != nil { - source.WithOptional(*ref.Optional) - } - envVarFrom.WithConfigMapRef(source) - } - if ref := e.SecretRef; ref != nil { - source := corev1ac.SecretEnvSource().WithName(ref.Name) - if ref.Optional != nil { - source.WithOptional(*ref.Optional) - } - envVarFrom.WithSecretRef(source) - } - return envVarFrom -} - -func EnvFromSources(e ...corev1.EnvFromSource) []*corev1ac.EnvFromSourceApplyConfiguration { - var envs []*corev1ac.EnvFromSourceApplyConfiguration +func EnvVars(e ...corev1.EnvVar) []corev1ac.EnvVarApplyConfiguration { 
+ var envs []corev1ac.EnvVarApplyConfiguration for _, env := range e { - envs = append(envs, EnvFromSource(env)) + envs = append(envs, *EnvVar(env)) } return envs } - -func Condition(c metav1.Condition) *metav1ac.ConditionApplyConfiguration { - condition := metav1ac.Condition(). - WithObservedGeneration(c.ObservedGeneration) - if c.Type != "" { - condition.WithType(c.Type) - } - if c.Message != "" { - condition.WithMessage(c.Message) - } - if c.Reason != "" { - condition.WithReason(c.Reason) - } - if c.Status != "" { - condition.WithStatus(c.Status) - } - if !c.LastTransitionTime.IsZero() { - condition.WithLastTransitionTime(c.LastTransitionTime) - } - return condition -} - -func Conditions(c ...metav1.Condition) []*metav1ac.ConditionApplyConfiguration { - var conditions []*metav1ac.ConditionApplyConfiguration - for _, condition := range c { - conditions = append(conditions, Condition(condition)) - } - return conditions -} - -func Volume(v corev1.Volume) *corev1ac.VolumeApplyConfiguration { - volume := corev1ac.Volume().WithName(v.Name) - // FIXME - return volume -} - -func VolumeMount(m corev1.VolumeMount) *corev1ac.VolumeMountApplyConfiguration { - volumeMount := corev1ac.VolumeMount().WithName(m.Name) - if m.MountPath != "" { - volumeMount.WithMountPath(m.MountPath) - } - return volumeMount -} - -func VolumeMounts(v ...corev1.VolumeMount) []*corev1ac.VolumeMountApplyConfiguration { - var mounts []*corev1ac.VolumeMountApplyConfiguration - for _, mount := range v { - mounts = append(mounts, VolumeMount(mount)) - } - return mounts -} From 5df095d7c6da2304d8d7dcdf38c894d22125f6d7 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 17 Feb 2025 17:49:10 +0100 Subject: [PATCH 14/17] Do not update JobSets that are not suspended Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/plugins/jobset/jobset.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go 
b/pkg/runtime/framework/plugins/jobset/jobset.go index 0406b1c9f7..dfe6e2716b 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -23,12 +23,14 @@ import ( "maps" "github.com/go-logr/logr" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" metav1ac "k8s.io/client-go/applyconfigurations/meta/v1" + "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -90,6 +92,20 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine return nil, fmt.Errorf("runtime info or object is missing") } + // Do not update the JobSet if it already exists and is not suspended + oldJobSet := &jobsetv1alpha2.JobSet{} + if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil { + if !apierrors.IsNotFound(err) { + return nil, err + } + oldJobSet = nil + } + if oldJobSet != nil && + !ptr.Deref(trainJob.Spec.Suspend, false) && + !ptr.Deref(oldJobSet.Spec.Suspend, false) { + return nil, nil + } + // Get the runtime as unstructured from the TrainJob ref runtimeJobTemplate := &unstructured.Unstructured{} runtimeJobTemplate.SetAPIVersion(trainer.GroupVersion.String()) From 5573726ba87d0d22691a6e5253ae333143c20a19 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 25 Feb 2025 19:06:48 +0100 Subject: [PATCH 15/17] Address review feedback Signed-off-by: Antonin Stefanutti --- pkg/{util => }/apply/apply.go | 26 +++++++++---------- pkg/controller/trainjob_controller.go | 16 ++++++------ .../core/clustertrainingruntime_test.go | 2 +- pkg/runtime/core/trainingruntime_test.go | 2 +- pkg/runtime/framework/core/framework.go | 6 ++--- 
pkg/runtime/framework/core/framework_test.go | 9 +++---- .../framework/plugins/jobset/builder.go | 2 +- pkg/runtime/framework/plugins/mpi/mpi.go | 2 +- .../framework/plugins/plainml/plainml.go | 2 +- pkg/runtime/framework/plugins/torch/torch.go | 2 +- pkg/util/testing/runtime.go | 18 ++++++++----- pkg/webhooks/trainingruntime_webhook_test.go | 24 ++++++----------- 12 files changed, 53 insertions(+), 58 deletions(-) rename pkg/{util => }/apply/apply.go (76%) diff --git a/pkg/util/apply/apply.go b/pkg/apply/apply.go similarity index 76% rename from pkg/util/apply/apply.go rename to pkg/apply/apply.go index 83a52a7eea..d995668736 100644 --- a/pkg/util/apply/apply.go +++ b/pkg/apply/apply.go @@ -22,33 +22,33 @@ import ( "k8s.io/utils/ptr" ) -func UpsertEnvVar(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVar ...*corev1ac.EnvVarApplyConfiguration) { +func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...*corev1ac.EnvVarApplyConfiguration) { for _, e := range envVar { - upsert(envVarList, *e, byEnvVarName) + upsert(envVars, *e, byEnvVarName) } } -func UpsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars []corev1ac.EnvVarApplyConfiguration) { - for _, e := range envVars { - upsert(envVarList, e, byEnvVarName) +func UpsertEnvVars(envVars *[]corev1ac.EnvVarApplyConfiguration, upEnvVars []corev1ac.EnvVarApplyConfiguration) { + for _, e := range upEnvVars { + upsert(envVars, e, byEnvVarName) } } -func UpsertPort(portList *[]corev1ac.ContainerPortApplyConfiguration, port ...*corev1ac.ContainerPortApplyConfiguration) { +func UpsertPort(ports *[]corev1ac.ContainerPortApplyConfiguration, port ...*corev1ac.ContainerPortApplyConfiguration) { for _, p := range port { - upsert(portList, *p, byContainerPortOrName) + upsert(ports, *p, byContainerPortOrName) } } -func UpsertVolumes(volumeList *[]corev1ac.VolumeApplyConfiguration, volumes []corev1ac.VolumeApplyConfiguration) { - for _, v := range volumes { - upsert(volumeList, v, byVolumeName) 
+func UpsertVolumes(volumes *[]corev1ac.VolumeApplyConfiguration, upVolumes []corev1ac.VolumeApplyConfiguration) { + for _, v := range upVolumes { + upsert(volumes, v, byVolumeName) } } -func UpsertVolumeMounts(mountList *[]corev1ac.VolumeMountApplyConfiguration, mounts []corev1ac.VolumeMountApplyConfiguration) { - for _, m := range mounts { - upsert(mountList, m, byVolumeMountName) +func UpsertVolumeMounts(mounts *[]corev1ac.VolumeMountApplyConfiguration, upMounts []corev1ac.VolumeMountApplyConfiguration) { + for _, m := range upMounts { + upsert(mounts, m, byVolumeMountName) } } diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index 6ea57a5a18..a4c070fb0e 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -27,7 +27,7 @@ import ( "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - k8sruntime "k8s.io/apimachinery/pkg/runtime" + apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" @@ -117,14 +117,14 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru // client. 
See https://github.com/kubernetes/kubernetes/pull/129313 var obj client.Object if o, ok := object.(client.Object); ok { - obj = o - } else { - if u, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(object); err != nil { - return buildFailed, err - } else { - obj = &unstructured.Unstructured{Object: u} - } + return buildFailed, fmt.Errorf("unsupported type client.Object for component: %v", o) + } + + u, err := apiruntime.DefaultUnstructuredConverter.ToUnstructured(object) + if err != nil { + return buildFailed, err } + obj = &unstructured.Unstructured{Object: u} if err := r.client.Patch(ctx, obj, client.Apply, client.FieldOwner("trainer"), client.ForceOwnership); err != nil { return buildFailed, err diff --git a/pkg/runtime/core/clustertrainingruntime_test.go b/pkg/runtime/core/clustertrainingruntime_test.go index f63b9443db..447905c7c7 100644 --- a/pkg/runtime/core/clustertrainingruntime_test.go +++ b/pkg/runtime/core/clustertrainingruntime_test.go @@ -133,7 +133,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) if err != nil { - t.Fatal(err) + t.Errorf("Pipeline built unrecognizable objects: %v", err) } if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index afda250fdf..023114f31e 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -449,7 +449,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) 
if err != nil { - t.Fatal(err) + t.Errorf("Pipeline built unrecognizable objects: %v", err) } if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 476c2d333a..264a148d6e 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -115,11 +115,11 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { var objs []any for _, plugin := range f.componentBuilderPlugins { - if components, err := plugin.Build(ctx, info, trainJob); err != nil { + components, err := plugin.Build(ctx, info, trainJob) + if err != nil { return nil, err - } else if components != nil { - objs = append(objs, components...) } + objs = append(objs, components...) } return objs, nil } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index d0f7c4812c..c1f96d90a1 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -27,7 +27,6 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" - k8sruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -395,7 +394,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { trainingRuntime *trainer.TrainingRuntime trainJob *trainer.TrainJob wantRuntimeInfo *runtime.Info - wantObjs []k8sruntime.Object + wantObjs []apiruntime.Object wantError error }{ "succeeded to build PodGroup and JobSet with NumNodes from TrainJob": { @@ -470,7 +469,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }, }, }, - wantObjs: []k8sruntime.Object{ + wantObjs: 
[]apiruntime.Object{ testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). SchedulingTimeout(300). MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer. @@ -491,7 +490,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { "an empty registry": {}, } cmpOpts := []cmp.Option{ - cmpopts.SortSlices(func(a, b k8sruntime.Object) bool { + cmpopts.SortSlices(func(a, b apiruntime.Object) bool { return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String() }), cmpopts.EquateEmpty(), @@ -530,7 +529,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { resultObjs, err := testingutil.ToObject(c.Scheme(), objs...) if err != nil { - t.Fatal(err) + t.Errorf("Pipeline built unrecognizable objects: %v", err) } if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 { diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index c59326e35e..d575b130a1 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -21,9 +21,9 @@ import ( jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/pkg/apply" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" - "github.com/kubeflow/trainer/pkg/util/apply" ) type Builder struct { diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index ec81da74db..34228814fe 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -40,10 +40,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/pkg/apply" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" 
"github.com/kubeflow/trainer/pkg/runtime/framework" - "github.com/kubeflow/trainer/pkg/util/apply" ) type MPI struct { diff --git a/pkg/runtime/framework/plugins/plainml/plainml.go b/pkg/runtime/framework/plugins/plainml/plainml.go index b389b97d58..4fbb4300ac 100644 --- a/pkg/runtime/framework/plugins/plainml/plainml.go +++ b/pkg/runtime/framework/plugins/plainml/plainml.go @@ -23,10 +23,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/pkg/apply" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" - "github.com/kubeflow/trainer/pkg/util/apply" ) var _ framework.EnforceMLPolicyPlugin = (*PlainML)(nil) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index bf715dd215..88d4fa5b06 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -28,10 +28,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/pkg/apply" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" "github.com/kubeflow/trainer/pkg/runtime/framework" - "github.com/kubeflow/trainer/pkg/util/apply" ) type Torch struct{} diff --git a/pkg/util/testing/runtime.go b/pkg/util/testing/runtime.go index 1cf4abd7d0..1c85c410a9 100644 --- a/pkg/util/testing/runtime.go +++ b/pkg/util/testing/runtime.go @@ -1,5 +1,5 @@ /* -Copyright 2024 The Kubeflow Authors. +Copyright 2025 The Kubeflow Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -27,11 +27,11 @@ import ( func ToObject(s *runtime.Scheme, objects ...any) ([]runtime.Object, error) { var objs []runtime.Object for _, obj := range objects { - if o, err := toObject(s, obj); err != nil { + o, err := toObject(s, obj) + if err != nil { return nil, err - } else { - objs = append(objs, o) } + objs = append(objs, o) } return objs, nil } @@ -40,12 +40,14 @@ func toObject(s *runtime.Scheme, obj any) (runtime.Object, error) { if o, ok := obj.(runtime.Object); ok { return o, nil } + var u *unstructured.Unstructured - if o, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj); err != nil { + o, err := runtime.DefaultUnstructuredConverter.ToUnstructured(obj) + if err != nil { return nil, err - } else { - u = &unstructured.Unstructured{Object: o} } + u = &unstructured.Unstructured{Object: o} + gvk := u.GetObjectKind().GroupVersionKind() if !s.Recognizes(gvk) { return nil, fmt.Errorf("%s is not a recognized schema", gvk.GroupVersion().String()) @@ -54,6 +56,7 @@ func toObject(s *runtime.Scheme, obj any) (runtime.Object, error) { if err != nil { return nil, fmt.Errorf("scheme recognizes %s but failed to produce an object for it: %w", gvk, err) } + raw, err := json.Marshal(u) if err != nil { return nil, fmt.Errorf("failed to serialize %T: %w", raw, err) @@ -61,5 +64,6 @@ func toObject(s *runtime.Scheme, obj any) (runtime.Object, error) { if err := json.Unmarshal(raw, typed); err != nil { return nil, fmt.Errorf("failed to unmarshal the content of %T into %T: %w", u, typed, err) } + return typed, nil } diff --git a/pkg/webhooks/trainingruntime_webhook_test.go b/pkg/webhooks/trainingruntime_webhook_test.go index 04710b5225..9423e46d48 100644 --- a/pkg/webhooks/trainingruntime_webhook_test.go +++ b/pkg/webhooks/trainingruntime_webhook_test.go @@ -23,6 +23,8 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "k8s.io/apimachinery/pkg/util/validation/field" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + testingutil 
"github.com/kubeflow/trainer/pkg/util/testing" ) func TestValidateReplicatedJobs(t *testing.T) { @@ -31,24 +33,14 @@ func TestValidateReplicatedJobs(t *testing.T) { wantError field.ErrorList }{ "valid replicatedJobs": { - rJobs: []jobsetv1alpha2.ReplicatedJob{ - { - Replicas: 1, - }, - { - Replicas: 1, - }, - }, + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(1). + Obj().Spec.ReplicatedJobs, }, "invalid replicas": { - rJobs: []jobsetv1alpha2.ReplicatedJob{ - { - Replicas: 2, - }, - { - Replicas: 2, - }, - }, + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(2). + Obj().Spec.ReplicatedJobs, wantError: field.ErrorList{ field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(0).Child("replicas"), "2", ""), From 441703714edfaf2a1deabf7dc12dcfd468ee4bd2 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 25 Feb 2025 21:46:53 +0100 Subject: [PATCH 16/17] Do not update PodGroup if TrainJob is not suspended Signed-off-by: Antonin Stefanutti --- .../plugins/coscheduling/coscheduling.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go index ada4ac70e6..8678c12b80 100644 --- a/pkg/runtime/framework/plugins/coscheduling/coscheduling.go +++ b/pkg/runtime/framework/plugins/coscheduling/coscheduling.go @@ -25,6 +25,7 @@ import ( "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" nodev1 "k8s.io/api/node/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -101,11 +102,23 @@ func (c *CoScheduling) EnforcePodGroupPolicy(info *runtime.Info, trainJob *train return nil } -func (c *CoScheduling) Build(_ context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { +func (c *CoScheduling) Build(ctx 
context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) { if info == nil || info.RuntimePolicy.PodGroupPolicy == nil || info.RuntimePolicy.PodGroupPolicy.Coscheduling == nil || trainJob == nil { return nil, nil } + // Do not update the PodGroup if it already exists and the TrainJob is not suspended + oldPodGroup := &schedulerpluginsv1alpha1.PodGroup{} + if err := c.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldPodGroup); err != nil { + if !apierrors.IsNotFound(err) { + return nil, err + } + oldPodGroup = nil + } + if oldPodGroup != nil && !ptr.Deref(trainJob.Spec.Suspend, false) { + return nil, nil + } + var totalMembers int32 totalResources := make(corev1.ResourceList) for _, resourceRequests := range info.TotalRequests { From 09d467251599bab888e0769238b673d103547e1c Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Wed, 26 Feb 2025 08:52:43 +0100 Subject: [PATCH 17/17] Remove obsolete TODO Signed-off-by: Antonin Stefanutti --- pkg/runtime/runtime.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 4352a7cb49..8216be8abb 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -46,12 +46,10 @@ type RuntimePolicy struct { type Trainer struct { NumNodes *int32 NumProcPerNode string - // TODO (andreyvelich). Potentially, we can use map for env and sort it to improve code. 
- // Context: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823267183 - Env []corev1ac.EnvVarApplyConfiguration - ContainerPort *corev1ac.ContainerPortApplyConfiguration - Volumes []corev1ac.VolumeApplyConfiguration - VolumeMounts []corev1ac.VolumeMountApplyConfiguration + Env []corev1ac.EnvVarApplyConfiguration + ContainerPort *corev1ac.ContainerPortApplyConfiguration + Volumes []corev1ac.VolumeApplyConfiguration + VolumeMounts []corev1ac.VolumeMountApplyConfiguration } // TODO (andreyvelich): Potentially, we can add ScheduleTimeoutSeconds to the Scheduler for consistency.