diff --git a/pkg/controller/workload/mpijob/mpijob_controller.go b/pkg/controller/workload/mpijob/mpijob_controller.go
index 4f863131a6..99693cd9a4 100644
--- a/pkg/controller/workload/mpijob/mpijob_controller.go
+++ b/pkg/controller/workload/mpijob/mpijob_controller.go
@@ -42,6 +42,7 @@ import (
 	"sigs.k8s.io/kueue/pkg/constants"
 	"sigs.k8s.io/kueue/pkg/controller/workload/common"
 	workload_common "sigs.k8s.io/kueue/pkg/controller/workload/common"
+	utilpriority "sigs.k8s.io/kueue/pkg/util/priority"
 	"sigs.k8s.io/kueue/pkg/workload"
 )
 
@@ -461,6 +462,9 @@ func ConstructWorkloadFor(ctx context.Context, client client.Client,
 		},
 	}
 
+	highestPriorityClassName := ""
+	highestPriority := int32(0)
+
 	for index, mpiReplicaType := range orderedReplicaTypes(&job.Spec) {
 		podSet := kueue.PodSet{
 			Name:  strconv.Itoa(index),
@@ -468,7 +472,21 @@ func ConstructWorkloadFor(ctx context.Context, client client.Client,
 			Count: podsCount(&job.Spec, mpiReplicaType),
 		}
 		w.Spec.PodSets = append(w.Spec.PodSets, podSet)
+
+		priorityClassName, p, err := utilpriority.GetPriorityFromPriorityClass(
+			ctx, client, job.Spec.MPIReplicaSpecs[mpiReplicaType].Template.Spec.PriorityClassName)
+		if err != nil {
+			return nil, err
+		}
+		if p > highestPriority {
+			highestPriority = p
+			highestPriorityClassName = priorityClassName
+		}
 	}
+	// Populate priority from priority class.
+	w.Spec.Priority = &highestPriority
+	w.Spec.PriorityClassName = highestPriorityClassName
+
 	if err := ctrl.SetControllerReference(job, w, scheme); err != nil {
 		return nil, err
 	}
diff --git a/pkg/util/testing/wrappers_mpijob.go b/pkg/util/testing/wrappers_mpijob.go
index a65885721d..860bd61543 100644
--- a/pkg/util/testing/wrappers_mpijob.go
+++ b/pkg/util/testing/wrappers_mpijob.go
@@ -79,6 +79,14 @@ func MakeMPIJob(name, ns string) *MPIJobWrapper {
 	}}
 }
 
+// PriorityClass updates job priorityclass.
+func (j *MPIJobWrapper) PriorityClass(pc string) *MPIJobWrapper {
+	for replicaType := range j.Spec.MPIReplicaSpecs {
+		j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.PriorityClassName = pc
+	}
+	return j
+}
+
 // Obj returns the inner Job.
 func (j *MPIJobWrapper) Obj() *kubeflow.MPIJob {
 	return &j.MPIJob
diff --git a/test/integration/controller/mpijob/mpijob_controller_test.go b/test/integration/controller/mpijob/mpijob_controller_test.go
index 10a8421161..7de3b6d33d 100644
--- a/test/integration/controller/mpijob/mpijob_controller_test.go
+++ b/test/integration/controller/mpijob/mpijob_controller_test.go
@@ -17,13 +17,18 @@ limitations under the License.
 package mpijob
 
 import (
+	"fmt"
+
 	"github.com/onsi/ginkgo/v2"
 	"github.com/onsi/gomega"
+	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/kubernetes/scheme"
 
 	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
 	kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2"
+	"sigs.k8s.io/kueue/pkg/constants"
 	"sigs.k8s.io/kueue/pkg/controller/workload/common"
 	workloadmpijob "sigs.k8s.io/kueue/pkg/controller/workload/mpijob"
 	"sigs.k8s.io/kueue/pkg/util/testing"
@@ -32,9 +37,12 @@ import (
 )
 
 const (
-	jobName      = "test-job"
-	jobNamespace = "default"
-	jobKey       = jobNamespace + "/" + jobName
+	jobName           = "test-job"
+	jobNamespace      = "default"
+	labelKey          = "cloud.provider.com/instance"
+	jobKey            = jobNamespace + "/" + jobName
+	priorityClassName = "test-priority-class"
+	priorityValue     = 10
 )
 
 var (
@@ -60,7 +68,11 @@ var _ = ginkgo.Describe("Job controller", func() {
 
 	ginkgo.It("Should reconcile workload and job client", func() {
 		ginkgo.By("checking the job gets suspended when created unsuspended")
-		job := testing.MakeMPIJob(jobName, jobNamespace).Obj()
+		priorityClass := testing.MakePriorityClass(priorityClassName).
+			PriorityValue(int32(priorityValue)).Obj()
+		gomega.Expect(k8sClient.Create(ctx, priorityClass)).Should(gomega.Succeed())
+
+		job := testing.MakeMPIJob(jobName, jobNamespace).PriorityClass(priorityClassName).Obj()
 		err := k8sClient.Create(ctx, job)
 		gomega.Expect(err).To(gomega.Succeed())
 		createdJob := &kubeflow.MPIJob{}
@@ -79,5 +91,99 @@ var _ = ginkgo.Describe("Job controller", func() {
 		}, util.Timeout, util.Interval).Should(gomega.Succeed())
 		gomega.Expect(createdWorkload.Spec.QueueName).Should(gomega.Equal(""), "The Workload shouldn't have .spec.queueName set")
 		gomega.Expect(metav1.IsControlledBy(createdWorkload, createdJob)).To(gomega.BeTrue(), "The Workload should be owned by the Job")
+
+		ginkgo.By("checking the workload is created with priority and priorityName")
+		gomega.Expect(createdWorkload.Spec.PriorityClassName).Should(gomega.Equal(priorityClassName))
+		gomega.Expect(*createdWorkload.Spec.Priority).Should(gomega.Equal(int32(priorityValue)))
+
+		ginkgo.By("checking the workload is updated with queue name when the job does")
+		jobQueueName := "test-queue"
+		createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName}
+		gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed())
+		gomega.Eventually(func() bool {
+			if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
+				return false
+			}
+			return createdWorkload.Spec.QueueName == jobQueueName
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+
+		ginkgo.By("checking a second non-matching workload is deleted")
+		secondWl, _ := workloadmpijob.ConstructWorkloadFor(ctx, k8sClient, createdJob, scheme.Scheme)
+		secondWl.Name = common.GetNameForJob("second-workload")
+		secondWl.Spec.PodSets[0].Count += 1
+		gomega.Expect(k8sClient.Create(ctx, secondWl)).Should(gomega.Succeed())
+		gomega.Eventually(func() error {
+			wl := &kueue.Workload{}
+			key := types.NamespacedName{Name: secondWl.Name, Namespace: secondWl.Namespace}
+			return k8sClient.Get(ctx, key, wl)
+		}, util.Timeout, util.Interval).Should(testing.BeNotFoundError())
+		// check the original wl is still there
+		gomega.Consistently(func() bool {
+			err := k8sClient.Get(ctx, wlLookupKey, createdWorkload)
+			return err == nil
+		}, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue())
+
+		ginkgo.By("checking the job is unsuspended when workload is assigned")
+		onDemandFlavor := testing.MakeResourceFlavor("on-demand").Label(labelKey, "on-demand").Obj()
+		gomega.Expect(k8sClient.Create(ctx, onDemandFlavor)).Should(gomega.Succeed())
+		spotFlavor := testing.MakeResourceFlavor("spot").Label(labelKey, "spot").Obj()
+		gomega.Expect(k8sClient.Create(ctx, spotFlavor)).Should(gomega.Succeed())
+		clusterQueue := testing.MakeClusterQueue("cluster-queue").
+			Resource(testing.MakeResource(corev1.ResourceCPU).
+				Flavor(testing.MakeFlavor(onDemandFlavor.Name, "5").Obj()).
+				Flavor(testing.MakeFlavor(spotFlavor.Name, "5").Obj()).
+				Obj()).Obj()
+		createdWorkload.Spec.Admission = &kueue.Admission{
+			ClusterQueue: kueue.ClusterQueueReference(clusterQueue.Name),
+			PodSetFlavors: []kueue.PodSetFlavors{{
+				Name: "Launcher",
+				Flavors: map[corev1.ResourceName]string{
+					corev1.ResourceCPU: onDemandFlavor.Name,
+				},
+			}, {
+				Name: "Worker",
+				Flavors: map[corev1.ResourceName]string{
+					corev1.ResourceCPU: onDemandFlavor.Name,
+				},
+			}},
+		}
+		lookupKey := types.NamespacedName{Name: jobName, Namespace: jobNamespace}
+		gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed())
+		gomega.Eventually(func() bool {
+			if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil {
+				return false
+			}
+			return !*createdJob.Spec.RunPolicy.Suspend
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+		gomega.Eventually(func() bool {
+			ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "Started", corev1.EventTypeNormal, fmt.Sprintf("Admitted by clusterQueue %v", clusterQueue.Name))
+			return ok
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+		gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector)).Should(gomega.Equal(1))
+		gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(onDemandFlavor.Name))
+		gomega.Consistently(func() bool {
+			if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
+				return false
+			}
+			return len(createdWorkload.Status.Conditions) == 0
+		}, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue())
+
+		ginkgo.By("checking the workload is finished when job is completed")
+		createdJob.Status.Conditions = append(createdJob.Status.Conditions,
+			kubeflow.JobCondition{
+				Type:               kubeflow.JobSucceeded,
+				Status:             corev1.ConditionTrue,
+				LastTransitionTime: metav1.Now(),
+			})
+		gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed())
+		gomega.Eventually(func() bool {
+			err := k8sClient.Get(ctx, wlLookupKey, createdWorkload)
+			if err != nil || len(createdWorkload.Status.Conditions) == 0 {
+				return false
+			}
+
+			return createdWorkload.Status.Conditions[0].Type == kueue.WorkloadFinished &&
+				createdWorkload.Status.Conditions[0].Status == metav1.ConditionTrue
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
 	})
 })