Skip to content

Commit

Permalink
KEP-2170: Use SSA to reconcile TrainJob components
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Feb 11, 2025
1 parent cc40702 commit 81c566f
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 242 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
sigs.k8s.io/controller-runtime v0.19.1
sigs.k8s.io/jobset v0.5.2
sigs.k8s.io/kueue v0.6.3
sigs.k8s.io/scheduler-plugins v0.28.9
sigs.k8s.io/scheduler-plugins v0.30.6
sigs.k8s.io/structured-merge-diff/v4 v4.4.1
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMm
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/kueue v0.6.3 h1:PmccdKPDFQIaAboyuSG6M0w6hXtxVA51RV+DjCUtBtQ=
sigs.k8s.io/kueue v0.6.3/go.mod h1:rliYfK/K7pJ7CT4ReV1szzciNkAo3sBn5Bmr5Sn6uCY=
sigs.k8s.io/scheduler-plugins v0.28.9 h1:1/bXRoXuSUFr1FLqxrzScdyZMl/G1psuDJcDKYxTo+Q=
sigs.k8s.io/scheduler-plugins v0.28.9/go.mod h1:32+kIPGT0aTRsEDzKNga7zCbcCHK0dSk5UFCY+gzCLE=
sigs.k8s.io/scheduler-plugins v0.30.6 h1:P4pViMVoyVNHWmkG96UtJ4LvxkUIeenIUKLZd09vDyw=
sigs.k8s.io/scheduler-plugins v0.30.6/go.mod h1:EDYYqHmpHR//VYKAeud1TTQbTFSvpdGFeyEg9ejOmnI=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4=
sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08=
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
Expand Down
26 changes: 6 additions & 20 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,18 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
return buildFailed, err
}

logKeysAndValues := []any{
"groupVersionKind", gvk.String(),
"namespace", obj.GetNamespace(),
"name", obj.GetName(),
}
// TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence.
// Non-empty resourceVersion indicates UPDATE operation.
var creationErr error
var created bool
if obj.GetResourceVersion() == "" {
creationErr = r.client.Create(ctx, obj)
created = creationErr == nil
}
switch {
case created:
log.V(5).Info("Succeeded to create object", logKeysAndValues...)
continue
case client.IgnoreAlreadyExists(creationErr) != nil:
return creationFailed, creationErr
default:
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
if err = r.client.Update(ctx, obj); err != nil {
return updateFailed, err
}
log.V(5).Info("Succeeded to update object", logKeysAndValues...)

if err := r.client.Patch(ctx, obj, client.Apply, client.FieldOwner("trainer"), client.ForceOwnership); err != nil {
return buildFailed, err
}

log.V(5).Info("Succeeded to update object", logKeysAndValues...)
}
return creationSucceeded, nil
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/runtime/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -52,7 +53,7 @@ func NewClusterTrainingRuntime(context.Context, client.Client, client.FieldIndex
}, nil
}

func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) {
func (r *ClusterTrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) {
var clTrainingRuntime trainer.ClusterTrainingRuntime
if err := r.client.Get(ctx, client.ObjectKey{Name: trainJob.Spec.RuntimeRef.Name}, &clTrainingRuntime); err != nil {
return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedClusterTrainingRuntime, err)
Expand Down
12 changes: 4 additions & 8 deletions pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import (
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
Expand Down Expand Up @@ -71,7 +71,7 @@ func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.Fie
return trainingRuntimeFactory, nil
}

func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]client.Object, error) {
func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) {
var trainingRuntime trainer.TrainingRuntime
err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime)
if err != nil {
Expand All @@ -82,7 +82,7 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai

func (r *TrainingRuntime) buildObjects(
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
) ([]client.Object, error) {
) ([]*unstructured.Unstructured, error) {
propagationLabels := jobSetTemplateSpec.Labels
if propagationLabels == nil && trainJob.Spec.Labels != nil {
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
Expand Down Expand Up @@ -121,11 +121,7 @@ func (r *TrainingRuntime) buildObjects(
return nil, err
}

jobSetTemplate := jobsetv1alpha2.JobSet{
Spec: jobSetTemplateSpec.Spec,
}

return r.framework.RunComponentBuilderPlugins(ctx, jobSetTemplate.DeepCopy(), info, trainJob)
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
Expand Down
27 changes: 15 additions & 12 deletions pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ import (
"context"
"errors"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
fwkplugins "github.com/kubeflow/trainer/pkg/runtime/framework/plugins"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
k8sruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
)

var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
Expand Down Expand Up @@ -112,15 +113,17 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob)
return aggregatedWarnings, aggregatedErrors
}

func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) ([]client.Object, error) {
var objs []client.Object
func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]*unstructured.Unstructured, error) {
var objs []*unstructured.Unstructured
for _, plugin := range f.componentBuilderPlugins {
obj, err := plugin.Build(ctx, runtimeJobTemplate, info, trainJob)
if err != nil {
if component, err := plugin.Build(ctx, info, trainJob); err != nil {
return nil, err
}
if obj != nil {
objs = append(objs, obj)
} else if component != nil {
if content, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(component); err != nil {
return nil, err
} else {
objs = append(objs, &unstructured.Unstructured{Object: content})
}
}
}
return objs, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
if err = fwk.RunEnforceMLPolicyPlugins(tc.runtimeInfo, tc.trainJob); err != nil {
t.Fatal(err)
}
objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeJobTemplate, tc.runtimeInfo, tc.trainJob)
objs, err := fwk.RunComponentBuilderPlugins(ctx, tc.runtimeInfo, tc.trainJob)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected errors (-want,+got):\n%s", diff)
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/runtime/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
Expand Down Expand Up @@ -54,7 +53,7 @@ type CustomValidationPlugin interface {

type ComponentBuilderPlugin interface {
Plugin
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *trainer.TrainJob) (client.Object, error)
Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) (any, error)
}

type TerminalConditionPlugin interface {
Expand Down
68 changes: 23 additions & 45 deletions pkg/runtime/framework/plugins/coscheduling/coscheduling.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,31 @@ import (
"context"
"errors"
"fmt"
"maps"
"slices"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
runtimeindexer "github.com/kubeflow/trainer/pkg/runtime/indexer"
corev1 "k8s.io/api/core/v1"
nodev1 "k8s.io/api/node/v1"
"k8s.io/apimachinery/pkg/api/equality"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
metav1ac "k8s.io/client-go/applyconfigurations/meta/v1"
"k8s.io/client-go/util/workqueue"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
ctrlutil "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
"sigs.k8s.io/controller-runtime/pkg/source"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
runtimeindexer "github.com/kubeflow/trainer/pkg/runtime/indexer"
schedulerpluginsv1alpha1ac "sigs.k8s.io/scheduler-plugins/pkg/generated/applyconfiguration/scheduling/v1alpha1"
)

type CoScheduling struct {
Expand Down Expand Up @@ -103,7 +98,7 @@ func (c *CoScheduling) EnforcePodGroupPolicy(info *runtime.Info, trainJob *train
return nil
}

func (c *CoScheduling) Build(ctx context.Context, _ client.Object, info *runtime.Info, trainJob *trainer.TrainJob) (client.Object, error) {
func (c *CoScheduling) Build(_ context.Context, info *runtime.Info, trainJob *trainer.TrainJob) (any, error) {
if info == nil || info.RuntimePolicy.PodGroupPolicy == nil || info.RuntimePolicy.PodGroupPolicy.Coscheduling == nil || trainJob == nil {
return nil, nil
}
Expand All @@ -119,40 +114,23 @@ func (c *CoScheduling) Build(ctx context.Context, _ client.Object, info *runtime
totalResources[resName] = current
}
}
newPG := &schedulerpluginsv1alpha1.PodGroup{
TypeMeta: metav1.TypeMeta{
APIVersion: schedulerpluginsv1alpha1.SchemeGroupVersion.String(),
Kind: constants.PodGroupKind,
},
ObjectMeta: metav1.ObjectMeta{
Name: trainJob.Name,
Namespace: trainJob.Namespace,
},
Spec: schedulerpluginsv1alpha1.PodGroupSpec{
ScheduleTimeoutSeconds: info.RuntimePolicy.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds,
MinMember: totalMembers,
MinResources: totalResources,
},
}
if err := ctrlutil.SetControllerReference(trainJob, newPG, c.scheme); err != nil {
return nil, err
}
oldPG := &schedulerpluginsv1alpha1.PodGroup{}
if err := c.client.Get(ctx, client.ObjectKeyFromObject(newPG), oldPG); err != nil {
if !apierrors.IsNotFound(err) {
return nil, err
}
oldPG = nil
}
if needsCreateOrUpdate(oldPG, newPG, ptr.Deref(trainJob.Spec.Suspend, false)) {
return newPG, nil
}
return nil, nil
}

func needsCreateOrUpdate(old, new *schedulerpluginsv1alpha1.PodGroup, suspended bool) bool {
return old == nil ||
suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))
podGroup := schedulerpluginsv1alpha1ac.PodGroup(trainJob.Name, trainJob.Namespace)

podGroup.WithSpec(schedulerpluginsv1alpha1ac.PodGroupSpec().
WithScheduleTimeoutSeconds(totalMembers).
WithMinResources(totalResources).
WithScheduleTimeoutSeconds(*info.RuntimePolicy.PodGroupPolicy.Coscheduling.ScheduleTimeoutSeconds))

podGroup.WithOwnerReferences(metav1ac.OwnerReference().
WithAPIVersion(trainer.GroupVersion.String()).
WithKind(trainer.TrainJobKind).
WithName(trainJob.Name).
WithUID(trainJob.UID).
WithController(true).
WithBlockOwnerDeletion(true))

return podGroup, nil
}

type PodGroupRuntimeClassHandler struct {
Expand Down
Loading

0 comments on commit 81c566f

Please sign in to comment.