Skip to content

Commit

Permalink
Address review feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Feb 25, 2025
1 parent 29ab104 commit be59e6b
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 29 deletions.
26 changes: 13 additions & 13 deletions pkg/util/apply/apply.go → pkg/apply/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,33 @@ import (
"k8s.io/utils/ptr"
)

func UpsertEnvVar(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVar ...*corev1ac.EnvVarApplyConfiguration) {
func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...*corev1ac.EnvVarApplyConfiguration) {
for _, e := range envVar {
upsert(envVarList, *e, byEnvVarName)
upsert(envVars, *e, byEnvVarName)
}
}

func UpsertEnvVars(envVarList *[]corev1ac.EnvVarApplyConfiguration, envVars []corev1ac.EnvVarApplyConfiguration) {
for _, e := range envVars {
upsert(envVarList, e, byEnvVarName)
func UpsertEnvVars(envVars *[]corev1ac.EnvVarApplyConfiguration, upEnvVars []corev1ac.EnvVarApplyConfiguration) {
for _, e := range upEnvVars {
upsert(envVars, e, byEnvVarName)
}
}

func UpsertPort(portList *[]corev1ac.ContainerPortApplyConfiguration, port ...*corev1ac.ContainerPortApplyConfiguration) {
func UpsertPort(ports *[]corev1ac.ContainerPortApplyConfiguration, port ...*corev1ac.ContainerPortApplyConfiguration) {
for _, p := range port {
upsert(portList, *p, byContainerPortOrName)
upsert(ports, *p, byContainerPortOrName)
}
}

func UpsertVolumes(volumeList *[]corev1ac.VolumeApplyConfiguration, volumes []corev1ac.VolumeApplyConfiguration) {
for _, v := range volumes {
upsert(volumeList, v, byVolumeName)
func UpsertVolumes(volumes *[]corev1ac.VolumeApplyConfiguration, upVolumes []corev1ac.VolumeApplyConfiguration) {
for _, v := range upVolumes {
upsert(volumes, v, byVolumeName)
}
}

func UpsertVolumeMounts(mountList *[]corev1ac.VolumeMountApplyConfiguration, mounts []corev1ac.VolumeMountApplyConfiguration) {
for _, m := range mounts {
upsert(mountList, m, byVolumeMountName)
func UpsertVolumeMounts(mounts *[]corev1ac.VolumeMountApplyConfiguration, upMounts []corev1ac.VolumeMountApplyConfiguration) {
for _, m := range upMounts {
upsert(mounts, m, byVolumeMountName)
}
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
k8sruntime "k8s.io/apimachinery/pkg/runtime"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
Expand Down Expand Up @@ -117,9 +117,9 @@ func (r *TrainJobReconciler) reconcileObjects(ctx context.Context, runtime jobru
// client. See https://github.com/kubernetes/kubernetes/pull/129313
var obj client.Object
if o, ok := object.(client.Object); ok {
obj = o
return buildFailed, fmt.Errorf("unsupported type client.Object for component: %v", o)
} else {
if u, err := k8sruntime.DefaultUnstructuredConverter.ToUnstructured(object); err != nil {
if u, err := apiruntime.DefaultUnstructuredConverter.ToUnstructured(object); err != nil {
return buildFailed, err
} else {
obj = &unstructured.Unstructured{Object: u}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/core/clustertrainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {

resultObjs, err := testingutil.ToObject(c.Scheme(), objs...)
if err != nil {
t.Fatal(err)
t.Errorf("Pipeline built unrecognizable objects: %v", err)
}

if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {

resultObjs, err := testingutil.ToObject(c.Scheme(), objs...)
if err != nil {
t.Fatal(err)
t.Errorf("Pipeline built unrecognizable objects: %v", err)
}

if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
for _, plugin := range f.componentBuilderPlugins {
if components, err := plugin.Build(ctx, info, trainJob); err != nil {
return nil, err
} else if components != nil {
} else {
objs = append(objs, components...)
}
}
Expand Down
9 changes: 4 additions & 5 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
k8sruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -395,7 +394,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
trainingRuntime *trainer.TrainingRuntime
trainJob *trainer.TrainJob
wantRuntimeInfo *runtime.Info
wantObjs []k8sruntime.Object
wantObjs []apiruntime.Object
wantError error
}{
"succeeded to build PodGroup and JobSet with NumNodes from TrainJob": {
Expand Down Expand Up @@ -470,7 +469,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
},
},
},
wantObjs: []k8sruntime.Object{
wantObjs: []apiruntime.Object{
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
SchedulingTimeout(300).
MinMember(101). // 101 replicas = 100 Trainer nodes + 1 Initializer.
Expand All @@ -491,7 +490,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
"an empty registry": {},
}
cmpOpts := []cmp.Option{
cmpopts.SortSlices(func(a, b k8sruntime.Object) bool {
cmpopts.SortSlices(func(a, b apiruntime.Object) bool {
return a.GetObjectKind().GroupVersionKind().String() < b.GetObjectKind().GroupVersionKind().String()
}),
cmpopts.EquateEmpty(),
Expand Down Expand Up @@ -530,7 +529,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {

resultObjs, err := testingutil.ToObject(c.Scheme(), objs...)
if err != nil {
t.Fatal(err)
t.Errorf("Pipeline built unrecognizable objects: %v", err)
}

if diff := cmp.Diff(tc.wantObjs, resultObjs, cmpOpts...); len(diff) != 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/jobset/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/apply"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/util/apply"
)

type Builder struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/apply"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
"github.com/kubeflow/trainer/pkg/util/apply"
)

type MPI struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/plainml/plainml.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/apply"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
"github.com/kubeflow/trainer/pkg/util/apply"
)

var _ framework.EnforceMLPolicyPlugin = (*PlainML)(nil)
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/apply"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
"github.com/kubeflow/trainer/pkg/runtime/framework"
"github.com/kubeflow/trainer/pkg/util/apply"
)

type Torch struct{}
Expand Down
2 changes: 1 addition & 1 deletion pkg/util/testing/runtime.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2024 The Kubeflow Authors.
Copyright 2025 The Kubeflow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down

0 comments on commit be59e6b

Please sign in to comment.