Support priorities and more integration tests
mimowo committed Feb 21, 2023
1 parent 9a0a282 commit 2f8b14b
Showing 3 changed files with 136 additions and 4 deletions.
18 changes: 18 additions & 0 deletions pkg/controller/workload/mpijob/mpijob_controller.go
@@ -42,6 +42,7 @@
import (
"sigs.k8s.io/kueue/pkg/constants"
"sigs.k8s.io/kueue/pkg/controller/workload/common"
workload_common "sigs.k8s.io/kueue/pkg/controller/workload/common"
utilpriority "sigs.k8s.io/kueue/pkg/util/priority"
"sigs.k8s.io/kueue/pkg/workload"
)

@@ -461,14 +462,31 @@ func ConstructWorkloadFor(ctx context.Context, client client.Client,
},
}

highestPriorityClassName := ""
highestPriority := int32(0)

for index, mpiReplicaType := range orderedReplicaTypes(&job.Spec) {
podSet := kueue.PodSet{
Name: strconv.Itoa(index),
Spec: *job.Spec.MPIReplicaSpecs[mpiReplicaType].Template.Spec.DeepCopy(),
Count: podsCount(&job.Spec, mpiReplicaType),
}
w.Spec.PodSets = append(w.Spec.PodSets, podSet)

priorityClassName, p, err := utilpriority.GetPriorityFromPriorityClass(
ctx, client, job.Spec.MPIReplicaSpecs[mpiReplicaType].Template.Spec.PriorityClassName)
if err != nil {
return nil, err
}
if p > highestPriority {
highestPriority = p
highestPriorityClassName = priorityClassName
}
}
// Populate priority from priority class.
w.Spec.Priority = &highestPriority
w.Spec.PriorityClassName = highestPriorityClassName

if err := ctrl.SetControllerReference(job, w, scheme); err != nil {
return nil, err
}
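
For context on the new logic above: utilpriority.GetPriorityFromPriorityClass resolves a pod template's priorityClassName to the PriorityClass object's name and integer value, and the loop keeps the highest value seen across replica types. Below is a minimal sketch of such a resolver, assuming the scheduling/v1 API and a controller-runtime client; the real helper in pkg/util/priority may differ, e.g. by also handling a cluster-default class.

package priority

import (
	"context"

	schedulingv1 "k8s.io/api/scheduling/v1"
	"k8s.io/apimachinery/pkg/types"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// getPriorityFromPriorityClass resolves a priority class name to the
// class name and its integer value. Sketch only, not the actual kueue
// implementation.
func getPriorityFromPriorityClass(ctx context.Context, c client.Client, name string) (string, int32, error) {
	if name == "" {
		// No class on the pod template; it contributes priority 0.
		return "", 0, nil
	}
	pc := &schedulingv1.PriorityClass{}
	if err := c.Get(ctx, types.NamespacedName{Name: name}, pc); err != nil {
		return "", 0, err
	}
	return pc.Name, pc.Value, nil
}
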
8 changes: 8 additions & 0 deletions pkg/util/testing/wrappers_mpijob.go
@@ -79,6 +79,14 @@ func MakeMPIJob(name, ns string) *MPIJobWrapper {
}}
}

// PriorityClass updates the priority class name of all replica templates in the job.
func (j *MPIJobWrapper) PriorityClass(pc string) *MPIJobWrapper {
for replicaType := range j.Spec.MPIReplicaSpecs {
j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.PriorityClassName = pc
}
return j
}

// Obj returns the inner MPIJob.
func (j *MPIJobWrapper) Obj() *kubeflow.MPIJob {
return &j.MPIJob
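
A usage sketch for the new wrapper method, mirroring how the integration test below builds its job; buildJob and the package name are illustrative only, the string constants come from the test file.

package example

import (
	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
	"sigs.k8s.io/kueue/pkg/util/testing"
)

// buildJob assembles a test MPIJob whose launcher and worker templates
// both carry the same priority class name.
func buildJob() *kubeflow.MPIJob {
	return testing.MakeMPIJob("test-job", "default").
		PriorityClass("test-priority-class").
		Obj()
}
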
114 changes: 110 additions & 4 deletions test/integration/controller/mpijob/mpijob_controller_test.go
@@ -17,13 +17,18 @@ limitations under the License.
package mpijob

import (
"fmt"

"github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes/scheme"

kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2"
"sigs.k8s.io/kueue/pkg/constants"
"sigs.k8s.io/kueue/pkg/controller/workload/common"
workloadmpijob "sigs.k8s.io/kueue/pkg/controller/workload/mpijob"
"sigs.k8s.io/kueue/pkg/util/testing"
@@ -32,9 +37,12 @@ import (
)

const (
- jobName = "test-job"
- jobNamespace = "default"
- jobKey = jobNamespace + "/" + jobName
+ jobName = "test-job"
+ jobNamespace = "default"
+ labelKey = "cloud.provider.com/instance"
+ jobKey = jobNamespace + "/" + jobName
+ priorityClassName = "test-priority-class"
+ priorityValue = 10
)

var (
@@ -60,7 +68,11 @@ var _ = ginkgo.Describe("Job controller", func() {

ginkgo.It("Should reconcile workload and job client", func() {
ginkgo.By("checking the job gets suspended when created unsuspended")
- job := testing.MakeMPIJob(jobName, jobNamespace).Obj()
+ priorityClass := testing.MakePriorityClass(priorityClassName).
+ 	PriorityValue(int32(priorityValue)).Obj()
+ gomega.Expect(k8sClient.Create(ctx, priorityClass)).Should(gomega.Succeed())
+
+ job := testing.MakeMPIJob(jobName, jobNamespace).PriorityClass(priorityClassName).Obj()
err := k8sClient.Create(ctx, job)
gomega.Expect(err).To(gomega.Succeed())
createdJob := &kubeflow.MPIJob{}
@@ -79,5 +91,99 @@
}, util.Timeout, util.Interval).Should(gomega.Succeed())
gomega.Expect(createdWorkload.Spec.QueueName).Should(gomega.Equal(""), "The Workload shouldn't have .spec.queueName set")
gomega.Expect(metav1.IsControlledBy(createdWorkload, createdJob)).To(gomega.BeTrue(), "The Workload should be owned by the Job")

ginkgo.By("checking the workload is created with priority and priorityName")
gomega.Expect(createdWorkload.Spec.PriorityClassName).Should(gomega.Equal(priorityClassName))
gomega.Expect(*createdWorkload.Spec.Priority).Should(gomega.Equal(int32(priorityValue)))

ginkgo.By("checking the workload is updated with queue name when the job does")
jobQueueName := "test-queue"
createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName}
gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed())
gomega.Eventually(func() bool {
if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
return false
}
return createdWorkload.Spec.QueueName == jobQueueName
}, util.Timeout, util.Interval).Should(gomega.BeTrue())

ginkgo.By("checking a second non-matching workload is deleted")
secondWl, _ := workloadmpijob.ConstructWorkloadFor(ctx, k8sClient, createdJob, scheme.Scheme)
secondWl.Name = common.GetNameForJob("second-workload")
secondWl.Spec.PodSets[0].Count += 1
gomega.Expect(k8sClient.Create(ctx, secondWl)).Should(gomega.Succeed())
gomega.Eventually(func() error {
wl := &kueue.Workload{}
key := types.NamespacedName{Name: secondWl.Name, Namespace: secondWl.Namespace}
return k8sClient.Get(ctx, key, wl)
}, util.Timeout, util.Interval).Should(testing.BeNotFoundError())
// check the original wl is still there
gomega.Consistently(func() bool {
err := k8sClient.Get(ctx, wlLookupKey, createdWorkload)
return err == nil
}, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue())

ginkgo.By("checking the job is unsuspended when workload is assigned")
onDemandFlavor := testing.MakeResourceFlavor("on-demand").Label(labelKey, "on-demand").Obj()
gomega.Expect(k8sClient.Create(ctx, onDemandFlavor)).Should(gomega.Succeed())
spotFlavor := testing.MakeResourceFlavor("spot").Label(labelKey, "spot").Obj()
gomega.Expect(k8sClient.Create(ctx, spotFlavor)).Should(gomega.Succeed())
clusterQueue := testing.MakeClusterQueue("cluster-queue").
Resource(testing.MakeResource(corev1.ResourceCPU).
Flavor(testing.MakeFlavor(onDemandFlavor.Name, "5").Obj()).
Flavor(testing.MakeFlavor(spotFlavor.Name, "5").Obj()).
Obj()).Obj()
createdWorkload.Spec.Admission = &kueue.Admission{
ClusterQueue: kueue.ClusterQueueReference(clusterQueue.Name),
PodSetFlavors: []kueue.PodSetFlavors{{
Name: "Launcher",
Flavors: map[corev1.ResourceName]string{
corev1.ResourceCPU: onDemandFlavor.Name,
},
}, {
Name: "Worker",
Flavors: map[corev1.ResourceName]string{
corev1.ResourceCPU: onDemandFlavor.Name,
},
}},
}
lookupKey := types.NamespacedName{Name: jobName, Namespace: jobNamespace}
gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed())
gomega.Eventually(func() bool {
if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil {
return false
}
return !*createdJob.Spec.RunPolicy.Suspend
}, util.Timeout, util.Interval).Should(gomega.BeTrue())
gomega.Eventually(func() bool {
ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "Started", corev1.EventTypeNormal, fmt.Sprintf("Admitted by clusterQueue %v", clusterQueue.Name))
return ok
}, util.Timeout, util.Interval).Should(gomega.BeTrue())
gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector)).Should(gomega.Equal(1))
gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(onDemandFlavor.Name))
gomega.Consistently(func() bool {
if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
return false
}
return len(createdWorkload.Status.Conditions) == 0
}, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue())

ginkgo.By("checking the workload is finished when job is completed")
createdJob.Status.Conditions = append(createdJob.Status.Conditions,
kubeflow.JobCondition{
Type: kubeflow.JobSucceeded,
Status: corev1.ConditionTrue,
LastTransitionTime: metav1.Now(),
})
gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed())
gomega.Eventually(func() bool {
err := k8sClient.Get(ctx, wlLookupKey, createdWorkload)
if err != nil || len(createdWorkload.Status.Conditions) == 0 {
return false
}

return createdWorkload.Status.Conditions[0].Type == kueue.WorkloadFinished &&
createdWorkload.Status.Conditions[0].Status == metav1.ConditionTrue
}, util.Timeout, util.Interval).Should(gomega.BeTrue())
})
})
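
One behavior the controller change implies, though this test exercises only a single class: when replica types name different priority classes, ConstructWorkloadFor keeps the highest resolved value for the whole workload. A standalone illustration of that max-wins selection, with hypothetical class names and values:

package main

import "fmt"

// resolved pairs a priority class name with its integer value, as the
// controller sees it after looking up each replica template's class.
type resolved struct {
	name  string
	value int32
}

func main() {
	// Hypothetical: launcher resolves to 10, workers to 100.
	replicas := []resolved{{"launcher-prio", 10}, {"worker-prio", 100}}
	highestName, highest := "", int32(0)
	for _, r := range replicas {
		if r.value > highest {
			highest = r.value
			highestName = r.name
		}
	}
	// The workload would get PriorityClassName "worker-prio", Priority 100.
	fmt.Println(highestName, highest)
}
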
