diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index fcf6377e32..8716b7a4c8 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -72,6 +72,9 @@ const ( // TorchEnvMasterPort is the env name for the master node port. TorchEnvMasterPort string = "PET_MASTER_PORT" + // TorchEnvNamePrefix is the env name prefix for the distributed envs for torchrun. + TorchEnvNamePrefix = "PET_" + // JobLauncher is the Job name for the launcher. JobLauncher string = "launcher" @@ -111,6 +114,8 @@ const ( // Distributed envs for mpirun. // Values for OpenMPI implementation. OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile" + + UnsupportedRuntimeErrMsg string = "the specified runtime is not supported" ) var ( diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index addb7675eb..1962d31e10 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -39,8 +39,6 @@ import ( jobruntimes "github.com/kubeflow/trainer/pkg/runtime" ) -var errorUnsupportedRuntime = errors.New("the specified runtime is not supported") - type objsOpState int const ( @@ -83,10 +81,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c return ctrl.Result{}, nil } - runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String() + runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef) runtime, ok := r.runtimes[runtimeRefGK] if !ok { - return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK) + return ctrl.Result{}, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK) } opState, err := r.reconcileObjects(ctx, runtime, &trainJob) @@ -214,13 +212,6 @@ func isTrainJobFinished(trainJob *trainer.TrainJob) bool { meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed) } -func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind { - return schema.GroupKind{ - Group: ptr.Deref(runtimeRef.APIGroup, ""), - Kind: ptr.Deref(runtimeRef.Kind, ""), - } -} - func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error { b := ctrl.NewControllerManagedBy(mgr). WithOptions(options). 
diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 6fe0be1501..f8c448fd01 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -26,6 +26,7 @@ import ( "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/runtime" @@ -69,14 +70,19 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu } func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + clusterTrainingRuntime := &trainer.ClusterTrainingRuntime{} if err := r.client.Get(ctx, client.ObjectKey{ - Namespace: old.Namespace, - Name: old.Spec.RuntimeRef.Name, - }, &trainer.ClusterTrainingRuntime{}); err != nil { + Name: new.Spec.RuntimeRef.Name, + }, clusterTrainingRuntime); err != nil { return nil, field.ErrorList{ - field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef, + field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef, fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)), } } - return r.framework.RunCustomValidationPlugins(old, new) + info := r.runtimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy) + jobSetTemplate := jobsetv1alpha2.JobSet{ + Spec: clusterTrainingRuntime.Spec.Template.Spec, + } + return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new) } diff --git a/pkg/runtime/core/registry.go b/pkg/runtime/core/registry.go index 61cf6bef29..4b1cf94bbe 100644 --- a/pkg/runtime/core/registry.go +++ b/pkg/runtime/core/registry.go @@ -18,7 +18,6 @@ package core import ( "context" - "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeflow/trainer/pkg/runtime" diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index 7736a52b24..6704138509 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -83,6 +83,26 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai func (r *TrainingRuntime) buildObjects( ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy, ) ([]client.Object, error) { + + info := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy) + if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil { + return nil, err + } + + if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil { + return nil, err + } + + jobSetTemplate := jobsetv1alpha2.JobSet{ + Spec: jobSetTemplateSpec.Spec, + } + + return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) +} + +func (r *TrainingRuntime) runtimeInfo( + ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) *runtime.Info { + propagationLabels := jobSetTemplateSpec.Labels if propagationLabels == nil && trainJob.Spec.Labels != nil { propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) @@ -113,19 +133,7 @@ func (r *TrainingRuntime) buildObjects( info :=
runtime.NewInfo(opts...) - if err := r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil { - return nil, err - } - - if err := r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil { - return nil, err - } - - jobSetTemplate := jobsetv1alpha2.JobSet{ - Spec: jobSetTemplateSpec.Spec, - } - - return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob) + return info } func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { @@ -141,14 +149,19 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { } func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + trainingRuntime := &trainer.TrainingRuntime{} if err := r.client.Get(ctx, client.ObjectKey{ - Namespace: old.Namespace, - Name: old.Spec.RuntimeRef.Name, - }, &trainer.TrainingRuntime{}); err != nil { + Namespace: new.Namespace, + Name: new.Spec.RuntimeRef.Name, + }, trainingRuntime); err != nil { return nil, field.ErrorList{ - field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef, + field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef, fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)), } } - return r.framework.RunCustomValidationPlugins(old, new) + info := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) + jobSetTemplate := jobsetv1alpha2.JobSet{ + Spec: trainingRuntime.Spec.Template.Spec, + } + return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new) } diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 80e3bb4b60..362518545b 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -97,11 +97,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob return nil } -func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { +func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { var aggregatedWarnings admission.Warnings var aggregatedErrors field.ErrorList for _, plugin := range f.customValidationPlugins { - warnings, errs := plugin.Validate(oldObj, newObj) + warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj) if len(warnings) != 0 { aggregatedWarnings = append(aggregatedWarnings, warnings...) 
} diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 0712404d62..a16daebedd 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -82,6 +82,7 @@ func TestNew(t *testing.T) { customValidationPlugins: []framework.CustomValidationPlugin{ &mpi.MPI{}, &torch.Torch{}, + &jobset.JobSet{}, }, watchExtensionPlugins: []framework.WatchExtensionPlugin{ &coscheduling.CoScheduling{}, @@ -371,7 +372,9 @@ func TestRunCustomValidationPlugins(t *testing.T) { if err != nil { t.Fatal(err) } - warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj) + runtimeInfo := runtime.NewInfo() + jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test") + warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj) if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 { t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff) } diff --git a/pkg/runtime/framework/interface.go b/pkg/runtime/framework/interface.go index 9f208dd174..cf6712247f 100644 --- a/pkg/runtime/framework/interface.go +++ b/pkg/runtime/framework/interface.go @@ -34,7 +34,7 @@ type Plugin interface { type CustomValidationPlugin interface { Plugin - Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) + Validate(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) } type WatchExtensionPlugin interface { diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index e900fa4031..2a4117ced1 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -20,20 +20,24 @@ import ( "context" "fmt" "maps" + "slices" "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -52,6 +56,7 @@ type JobSet struct { var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) var _ framework.TerminalConditionPlugin = (*JobSet)(nil) +var _ framework.CustomValidationPlugin = (*JobSet)(nil) const Name = constants.JobSetKind @@ -159,3 +164,52 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ } return nil, nil } + +func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + + var allErrs field.ErrorList + specPath := field.NewPath("spec") + runtimeRefPath := specPath.Child("runtimeRef") + + jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet) + if !ok { + return nil, nil + } + + if 
newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil { + if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool { + return x.Name == constants.JobInitializer + }) { + allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have the %s job when trainJob is configured with input modelConfig", constants.JobInitializer))) + } else { + for _, job := range jobSet.Spec.ReplicatedJobs { + if job.Name == constants.JobInitializer { + if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool { + return x.Name == constants.ContainerModelInitializer + }) { + allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have a container named %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer))) + } + } + } + } + } + + if newObj.Spec.DatasetConfig != nil { + if !slices.ContainsFunc(jobSet.Spec.ReplicatedJobs, func(x jobsetv1alpha2.ReplicatedJob) bool { + return x.Name == constants.JobInitializer + }) { + allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have the %s job when trainJob is configured with input datasetConfig", constants.JobInitializer))) + } else { + for _, job := range jobSet.Spec.ReplicatedJobs { + if job.Name == constants.JobInitializer { + if !slices.ContainsFunc(job.Template.Spec.Template.Spec.InitContainers, func(x corev1.Container) bool { + return x.Name == constants.ContainerDatasetInitializer + }) { + allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("trainingRuntime should have a container named %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer))) + } + } + } + } + } + return nil, allErrs +} diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index d26c58c7c4..7efb494b5c 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -25,6 +25,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "k8s.io/apimachinery/pkg/util/intstr" "maps" "strconv" @@ -75,9 +76,19 @@ func (m *MPI) Name() string { return Name } -// TODO: Need to implement validations for MPI Policy.
-func (m *MPI) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { - return nil, nil +func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + specPath := field.NewPath("spec") + if newJobObj.Spec.Trainer != nil { + numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") + if runtimeInfo.RuntimePolicy.MLPolicy != nil && runtimeInfo.RuntimePolicy.MLPolicy.MPI != nil && newJobObj.Spec.Trainer.NumProcPerNode != nil { + numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode + if numProcPerNode.Type != intstr.Int { + allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "should have an int value")) + } + } + } + return nil, allErrs } func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index af9a04c456..2b3725fb99 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -19,6 +19,8 @@ package torch import ( "context" "fmt" + "slices" + "strings" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/intstr" @@ -49,11 +51,6 @@ func (t *Torch) Name() string { return Name } -// TODO: Need to implement validations for Torch policy. -func (t *Torch) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { - return nil, nil -} - // TODO (andreyvelich): Add support for PyTorch elastic when JobSet supports Elastic Jobs. func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { if info == nil || info.RuntimePolicy.MLPolicy == nil || info.RuntimePolicy.MLPolicy.Torch == nil { @@ -140,3 +137,31 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) return nil } + +func (t *Torch) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + var allErrs field.ErrorList + specPath := field.NewPath("spec") + + if newObj.Spec.Trainer != nil { + numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode") + if runtimeInfo.RuntimePolicy.MLPolicy != nil && + runtimeInfo.RuntimePolicy.MLPolicy.Torch != nil && newObj.Spec.Trainer.NumProcPerNode != nil { + numProcPerNode := *newObj.Spec.Trainer.NumProcPerNode + if numProcPerNode.Type == intstr.String { + allowedStringValList := []string{"auto", "cpu", "gpu"} + if !slices.Contains(allowedStringValList, numProcPerNode.StrVal) { + allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newObj.Spec.Trainer.NumProcPerNode, "should have an int value or auto/cpu/gpu")) + } + } + } + + if slices.ContainsFunc(newObj.Spec.Trainer.Env, func(x corev1.EnvVar) bool { + return strings.HasPrefix(x.Name, constants.TorchEnvNamePrefix) + }) { + trainerEnvsPath := specPath.Child("trainer").Child("env") + allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("should not have envs with the reserved name prefix %s", constants.TorchEnvNamePrefix))) + } + } + + return nil, allErrs +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index d69146fd99..8dc4d68e09 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -20,6 +20,8 @@ import ( "maps" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/utils/ptr" kueuelr
"sigs.k8s.io/kueue/pkg/util/limitrange" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" @@ -146,3 +148,10 @@ func NewInfo(opts ...InfoOption) *Info { return info } + +func RuntimeRefToRuntimeRegistryKey(runtimeRef trainer.RuntimeRef) string { + return schema.GroupKind{ + Group: ptr.Deref(runtimeRef.APIGroup, ""), + Kind: ptr.Deref(runtimeRef.Kind, ""), + }.String() +} diff --git a/pkg/webhooks/clustertrainingruntime_webhook.go b/pkg/webhooks/clustertrainingruntime_webhook.go index bf3696099d..8e31f4f4e7 100644 --- a/pkg/webhooks/clustertrainingruntime_webhook.go +++ b/pkg/webhooks/clustertrainingruntime_webhook.go @@ -51,8 +51,11 @@ func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(ctx context.Context, obj return nil, validateReplicatedJobs(clTrainingRuntime.Spec.Template.Spec.ReplicatedJobs).ToAggregate() } -func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(ctx context.Context, oldObj apiruntime.Object, newObj apiruntime.Object) (admission.Warnings, error) { + clTrainingRuntimeNew := newObj.(*trainer.ClusterTrainingRuntime) + log := ctrl.LoggerFrom(ctx).WithName("clustertrainingruntime-webhook") + log.V(5).Info("Validating update", "clusterTrainingRuntime", klog.KObj(clTrainingRuntimeNew)) + return nil, validateReplicatedJobs(clTrainingRuntimeNew.Spec.Template.Spec.ReplicatedJobs).ToAggregate() } func (w *ClusterTrainingRuntimeWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { diff --git a/pkg/webhooks/setup.go b/pkg/webhooks/setup.go index 6871bc6be7..9b7678e7ca 100644 --- a/pkg/webhooks/setup.go +++ b/pkg/webhooks/setup.go @@ -31,7 +31,7 @@ func Setup(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (string, error return trainer.TrainingRuntimeKind, err } if err := setupWebhookForTrainJob(mgr, runtimes); err != nil { - return "TrainJob", err + return trainer.TrainJobKind, err } return "", nil } diff --git a/pkg/webhooks/trainjob_webhook.go b/pkg/webhooks/trainjob_webhook.go index 49e2b74f73..5f3c3f6bfb 100644 --- a/pkg/webhooks/trainjob_webhook.go +++ b/pkg/webhooks/trainjob_webhook.go @@ -18,13 +18,16 @@ package webhooks import ( "context" + "fmt" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" ) @@ -43,12 +46,31 @@ func setupWebhookForTrainJob(mgr ctrl.Manager, run map[string]runtime.Runtime) e var _ webhook.CustomValidator = (*TrainJobWebhook)(nil) -func (w *TrainJobWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainJobWebhook) ValidateCreate(ctx context.Context, obj apiruntime.Object) (admission.Warnings, error) { + trainJob := obj.(*trainer.TrainJob) + log := ctrl.LoggerFrom(ctx).WithName("trainJob-webhook") + log.V(5).Info("Validating create", "TrainJob", klog.KObj(trainJob)) + runtimeRefGK := runtime.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef) + runtime, ok := w.runtimes[runtimeRefGK] + if !ok { + return nil, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK) + } + warnings, errorList := runtime.ValidateObjects(ctx, nil, trainJob) + return warnings, 
errorList.ToAggregate() } -func (w *TrainJobWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainJobWebhook) ValidateUpdate(ctx context.Context, oldObj apiruntime.Object, newObj apiruntime.Object) (admission.Warnings, error) { + oldTrainJob := oldObj.(*trainer.TrainJob) + newTrainJob := newObj.(*trainer.TrainJob) + log := ctrl.LoggerFrom(ctx).WithName("trainJob-webhook") + log.V(5).Info("Validating update", "TrainJob", klog.KObj(newTrainJob)) + runtimeRefGK := runtime.RuntimeRefToRuntimeRegistryKey(newTrainJob.Spec.RuntimeRef) + runtime, ok := w.runtimes[runtimeRefGK] + if !ok { + return nil, fmt.Errorf("%s: %s", constants.UnsupportedRuntimeErrMsg, runtimeRefGK) + } + warnings, errorList := runtime.ValidateObjects(ctx, oldTrainJob, newTrainJob) + return warnings, errorList.ToAggregate() } func (w *TrainJobWebhook) ValidateDelete(context.Context, apiruntime.Object) (admission.Warnings, error) { diff --git a/test/integration/controller/trainjob_controller_test.go b/test/integration/controller/trainjob_controller_test.go index 9378554788..dc10d95888 100644 --- a/test/integration/controller/trainjob_controller_test.go +++ b/test/integration/controller/trainjob_controller_test.go @@ -52,7 +52,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, true) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -513,11 +513,11 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ordered, func() { var ns *corev1.Namespace - + runtimeName := "training-runtime" ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -534,8 +534,36 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord }, } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) + + baseRuntimeWrapper := testingutil.MakeTrainingRuntimeWrapper(ns.Name, runtimeName) + baseClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName) + trainingRuntime := baseRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper( + testingutil.MakeTrainingRuntimeWrapper(ns.Name, runtimeName).Spec).Obj()).Obj() + clusterTrainingRuntime := baseClusterRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper( + testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName).Spec).Obj()).Obj() + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).To(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, clusterTrainingRuntime)).To(gomega.Succeed()) + + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Eventually(func() error { + err := k8sClient.Get(ctx, client.ObjectKeyFromObject(clusterTrainingRuntime), clusterTrainingRuntime) + if err != nil { + return err + } + return nil + }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.TrainingRuntime{}, client.InNamespace(ns.Name))).Should(gomega.Succeed()) + 
gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.ClusterTrainingRuntime{})).Should(gomega.Succeed()) gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.TrainJob{}, client.InNamespace(ns.Name))).Should(gomega.Succeed()) }) @@ -547,7 +575,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "managed-by-trainjob-controller"). ManagedBy("trainer.kubeflow.org/trainjob-controller"). - RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Obj() }, gomega.Succeed()), @@ -555,7 +583,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "managed-by-trainjob-controller"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Obj() }, gomega.Succeed()), @@ -563,7 +591,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "invalid-managed-by"). ManagedBy("invalid"). - RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Obj() }, testingutil.BeInvalidError()), @@ -577,53 +605,53 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-suspend"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), runtimeName). Obj() }, func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-suspend"). ManagedBy("kueue.x-k8s.io/multikueue"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), runtimeName). Suspend(false). Obj() }), ginkgo.Entry("Should succeed to default managedBy=trainer.kubeflow.org/trainjob-controller", func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-managed-by"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Suspend(true). Obj() }, func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "null-managed-by"). ManagedBy("trainer.kubeflow.org/trainjob-controller"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Suspend(true). Obj() }), ginkgo.Entry("Should succeed to default runtimeRef.apiGroup", func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-api-group"). - RuntimeRef(schema.GroupVersionKind{Group: "", Version: "", Kind: trainer.TrainingRuntimeKind}, "testing"). + RuntimeRef(schema.GroupVersionKind{Group: "", Version: "", Kind: trainer.TrainingRuntimeKind}, runtimeName). 
Obj() }, func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-api-group"). ManagedBy("trainer.kubeflow.org/trainjob-controller"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Suspend(false). Obj() }), ginkgo.Entry("Should succeed to default runtimeRef.kind", func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-kind"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(""), "testing"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(""), runtimeName). Obj() }, func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "empty-kind"). ManagedBy("trainer.kubeflow.org/trainjob-controller"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "testing"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), runtimeName). Suspend(false). Obj() }), @@ -643,7 +671,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "valid-managed-by"). ManagedBy("trainer.kubeflow.org/trainjob-controller"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). Obj() }, func(job *trainer.TrainJob) *trainer.TrainJob { @@ -654,7 +682,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord ginkgo.Entry("Should fail to update runtimeRef", func() *trainer.TrainJob { return testingutil.MakeTrainJobWrapper(ns.Name, "valid-runtimeref"). - RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "testing"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). 
Obj() }, func(job *trainer.TrainJob) *trainer.TrainJob { diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index 1e13d82a7e..c83d3e674c 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -42,7 +42,7 @@ import ( schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" - controller "github.com/kubeflow/trainer/pkg/controller" + "github.com/kubeflow/trainer/pkg/controller" runtimecore "github.com/kubeflow/trainer/pkg/runtime/core" kubeflowwebhooks "github.com/kubeflow/trainer/pkg/webhooks" ) @@ -72,7 +72,7 @@ func (f *Framework) Init() *rest.Config { return cfg } -func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client) { +func (f *Framework) RunManager(cfg *rest.Config, startControllers bool) (context.Context, client.Client) { webhookInstallOpts := &f.testEnv.WebhookInstallOptions gomega.ExpectWithOffset(1, trainer.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) gomega.ExpectWithOffset(1, jobsetv1alpha2.AddToScheme(scheme.Scheme)).NotTo(gomega.HaveOccurred()) @@ -114,6 +114,20 @@ func (f *Framework) RunManager(cfg *rest.Config) (context.Context, client.Client gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "controller", failedCtrlName) gomega.ExpectWithOffset(1, failedCtrlName).To(gomega.BeEmpty()) + if startControllers { + failedCtrlName, err := controller.SetupControllers(mgr, runtimes, ctrlpkg.Options{ + // controller-runtime v0.19+ validates controller names are unique, to make sure + // exported Prometheus metrics for each controller do not conflict. The current check + // relies on static state that's not compatible with testing execution model. 
+ // See the following resources for more context: + // https://github.com/kubernetes-sigs/controller-runtime/pull/2902#issuecomment-2284194683 + // https://github.com/kubernetes-sigs/controller-runtime/issues/2994 + SkipNameValidation: ptr.To(true), + }) + gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "controller", failedCtrlName) + gomega.ExpectWithOffset(1, failedCtrlName).To(gomega.BeEmpty()) + } + failedWebhookName, err := kubeflowwebhooks.Setup(mgr, runtimes) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred(), "webhook", failedWebhookName) gomega.ExpectWithOffset(1, failedWebhookName).To(gomega.BeEmpty()) diff --git a/test/integration/webhooks/clustertrainingruntime_webhook_test.go b/test/integration/webhooks/clustertrainingruntime_webhook_test.go index 047f0302c4..9c13287d33 100644 --- a/test/integration/webhooks/clustertrainingruntime_webhook_test.go +++ b/test/integration/webhooks/clustertrainingruntime_webhook_test.go @@ -35,7 +35,7 @@ var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() diff --git a/test/integration/webhooks/trainingruntime_webhook_test.go b/test/integration/webhooks/trainingruntime_webhook_test.go index 839293a3d9..164db69730 100644 --- a/test/integration/webhooks/trainingruntime_webhook_test.go +++ b/test/integration/webhooks/trainingruntime_webhook_test.go @@ -39,7 +39,7 @@ var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -85,7 +85,7 @@ var _ = ginkgo.Describe("TrainingRuntime marker validations and defaulting", gin ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() diff --git a/test/integration/webhooks/trainjob_test.go b/test/integration/webhooks/trainjob_test.go index baf9882edc..2b6084a02d 100644 --- a/test/integration/webhooks/trainjob_test.go +++ b/test/integration/webhooks/trainjob_test.go @@ -21,17 +21,27 @@ import ( "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" + testingutil "github.com/kubeflow/trainer/pkg/util/testing" "github.com/kubeflow/trainer/test/integration/framework" + "github.com/kubeflow/trainer/test/util" ) var _ = ginkgo.Describe("TrainJob Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace + var trainingRuntime *trainer.TrainingRuntime + var clusterTrainingRuntime *trainer.ClusterTrainingRuntime + runtimeName := "training-runtime" + jobName := "train-job" ginkgo.BeforeAll(func() { fwk = &framework.Framework{} cfg = fwk.Init() - ctx, k8sClient = fwk.RunManager(cfg) + ctx, k8sClient = fwk.RunManager(cfg, false) }) ginkgo.AfterAll(func() { fwk.Teardown() @@ -48,5 +58,114 @@ var _ = ginkgo.Describe("TrainJob Webhook", ginkgo.Ordered, func() { }, } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) + + baseRuntimeWrapper := testingutil.MakeTrainingRuntimeWrapper(ns.Name, 
runtimeName) + baseClusterRuntimeWrapper := testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName) + trainingRuntime = baseRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper( + testingutil.MakeTrainingRuntimeWrapper(ns.Name, runtimeName).Spec).Obj()).Obj() + clusterTrainingRuntime = baseClusterRuntimeWrapper.RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper( + testingutil.MakeClusterTrainingRuntimeWrapper(runtimeName).Spec).Obj()).Obj() + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).To(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, clusterTrainingRuntime)).To(gomega.Succeed()) + gomega.Eventually(func() error { + return k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Eventually(func() error { + return k8sClient.Get(ctx, client.ObjectKeyFromObject(clusterTrainingRuntime), clusterTrainingRuntime) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.TrainingRuntime{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.ClusterTrainingRuntime{})).To(gomega.Succeed()) + gomega.Expect(k8sClient.DeleteAllOf(ctx, &trainer.TrainJob{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + }) + + ginkgo.When("Creating TrainJob", func() { + ginkgo.DescribeTable("Validate TrainJob on creation", func(trainJob func() *trainer.TrainJob, errorMatcher gomega.OmegaMatcher) { + gomega.Expect(k8sClient.Create(ctx, trainJob())).Should(errorMatcher) + }, + ginkgo.Entry("Should succeed in creating trainJob with namespace scoped trainingRuntime", + func() *trainer.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + Obj() + }, + gomega.Succeed()), + ginkgo.Entry("Should fail in creating trainJob referencing trainingRuntime not present in the namespace", + func() *trainer.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "invalid"). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should succeed in creating trainJob with cluster scoped clusterTrainingRuntime", + func() *trainer.TrainJob { + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), runtimeName). + Obj() + }, + gomega.Succeed()), + ginkgo.Entry("Should fail in creating trainJob with pre-trained model config when referencing a trainingRuntime without an initializer", + func() *trainer.TrainJob { + trainingRuntime.Spec.Template = trainer.JobSetTemplateSpec{} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + ModelConfig(&trainer.ModelConfig{Input: &trainer.InputModel{}}).
+ Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with invalid trainer config for mpi runtime", + func() *trainer.TrainJob { + trainingRuntime.Spec.MLPolicy = &trainer.MLPolicy{MLPolicySource: trainer.MLPolicySource{MPI: &trainer.MPIMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + Trainer(&trainer.Trainer{NumProcPerNode: ptr.To(intstr.FromString("invalid"))}). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should fail in creating trainJob with invalid trainer config for torch runtime", + func() *trainer.TrainJob { + trainingRuntime.Spec.MLPolicy = &trainer.MLPolicy{MLPolicySource: trainer.MLPolicySource{Torch: &trainer.TorchMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + Trainer(&trainer.Trainer{NumProcPerNode: ptr.To(intstr.FromString("invalid"))}). + Obj() + }, + testingutil.BeForbiddenError()), + ginkgo.Entry("Should succeed in creating trainJob with valid trainer config for torch runtime", + func() *trainer.TrainJob { + trainingRuntime.Spec.MLPolicy = &trainer.MLPolicy{MLPolicySource: trainer.MLPolicySource{Torch: &trainer.TorchMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + Trainer(&trainer.Trainer{NumProcPerNode: ptr.To(intstr.FromString("auto"))}). + Obj() + }, + gomega.Succeed()), + ginkgo.Entry("Should fail in creating trainJob with trainer config having envs with PET_ prefix", + func() *trainer.TrainJob { + trainingRuntime.Spec.MLPolicy = &trainer.MLPolicy{MLPolicySource: trainer.MLPolicySource{Torch: &trainer.TorchMLPolicySource{}}} + gomega.Expect(k8sClient.Update(ctx, trainingRuntime)).To(gomega.Succeed()) + return testingutil.MakeTrainJobWrapper(ns.Name, jobName). + RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), runtimeName). + Trainer(&trainer.Trainer{Env: []corev1.EnvVar{{Name: "PET_X", Value: "test"}}}). + Obj() + }, + testingutil.BeForbiddenError()), + ) }) })
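For reviewers who want to exercise the new plugin-level validation without spinning up envtest, here is a minimal sketch of driving the Torch plugin's Validate hook directly. It is not part of this PR; it assumes the import paths, the zero-value plugin construction used in framework_test.go above, and the hypothetical example name Example_reservedTorchEnvs. The PET_NNODES env var carries the PET_ prefix (constants.TorchEnvNamePrefix) that the plugin reserves, so one error is expected.

package torch_test

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

	trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
	"github.com/kubeflow/trainer/pkg/runtime"
	"github.com/kubeflow/trainer/pkg/runtime/framework/plugins/torch"
)

func Example_reservedTorchEnvs() {
	// A TrainJob that tries to override a torchrun-managed env var; the Torch
	// plugin reserves the PET_ prefix, so Validate should reject it.
	trainJob := &trainer.TrainJob{
		Spec: trainer.TrainJobSpec{
			Trainer: &trainer.Trainer{
				Env: []corev1.EnvVar{{Name: "PET_NNODES", Value: "2"}},
			},
		},
	}
	plugin := &torch.Torch{}
	// A nil oldObj mirrors how the webhook calls ValidateObjects on create;
	// an empty runtime.Info works because the env check does not need MLPolicy.
	_, errs := plugin.Validate(&jobsetv1alpha2.JobSet{}, runtime.NewInfo(), nil, trainJob)
	fmt.Println(len(errs))
	// Output: 1
}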