diff --git a/.gitignore b/.gitignore
index 5e8709701d2..34854018dab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,3 +28,6 @@ vendor
 
 # Local Netlify folder
 .netlify
+
+# Directory with CRDs for dependencies
+dep-crds
diff --git a/Makefile b/Makefile
index b5f386e4769..d09bf3f08a5 100644
--- a/Makefile
+++ b/Makefile
@@ -25,6 +25,9 @@ GO_TEST_FLAGS ?= -race
 # Use go.mod go version as a single source of truth of GO version.
 GO_VERSION := $(shell awk '/^go /{print $$2}' go.mod|head -n1)
 
+# Use go.mod as a single source of truth for the mpi-operator version.
+MPI_VERSION := $(shell awk '/mpi-operator/{print $$2}' go.mod|head -n1)
+
 GIT_TAG ?= $(shell git describe --tags --dirty --always)
 # Image URL to use all building/pushing image targets
 PLATFORMS ?= linux/amd64,linux/arm64
@@ -141,7 +144,7 @@ test: generate fmt vet gotestsum ## Run tests.
 	$(GOTESTSUM) --junitfile $(ARTIFACTS)/junit.xml -- $(GO_TEST_FLAGS) $(shell go list ./... | grep -v '/test/') -coverprofile $(ARTIFACTS)/cover.out
 
 .PHONY: test-integration
-test-integration: manifests generate fmt vet envtest ginkgo ## Run tests.
+test-integration: manifests generate fmt vet envtest ginkgo mpi-operator-crd ## Run tests.
 	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) --arch=amd64 use $(ENVTEST_K8S_VERSION) -p path)" \
 	$(GINKGO) --junit-report=junit.xml --output-dir=$(ARTIFACTS) -v $(INTEGRATION_TARGET)
 
@@ -273,3 +276,8 @@ KIND = $(shell pwd)/bin/kind
 .PHONY: kind
 kind:
 	@GOBIN=$(PROJECT_DIR)/bin GO111MODULE=on $(GO_CMD) install sigs.k8s.io/kind@v0.16.0
+.PHONY: mpi-operator-crd
+mpi-operator-crd:
+	GOPATH=/tmp GO111MODULE=on $(GO_CMD) install github.com/kubeflow/mpi-operator/cmd/mpi-operator@$(MPI_VERSION)
+	mkdir -p $(shell pwd)/dep-crds/mpi-operator/
+	cp -f /tmp/pkg/mod/github.com/kubeflow/mpi-operator@$(MPI_VERSION)/manifests/base/* $(shell pwd)/dep-crds/mpi-operator/
diff --git a/config/components/rbac/kustomization.yaml b/config/components/rbac/kustomization.yaml
index 846ada9a690..77fba5fa2f4 100644
--- a/config/components/rbac/kustomization.yaml
+++ b/config/components/rbac/kustomization.yaml
@@ -29,3 +29,5 @@ resources:
 - workload_viewer_role.yaml
 - resourceflavor_editor_role.yaml
 - resourceflavor_viewer_role.yaml
+- mpijob_editor_role.yaml
+- mpijob_viewer_role.yaml
diff --git a/config/components/rbac/mpijob_editor_role.yaml b/config/components/rbac/mpijob_editor_role.yaml
new file mode 100644
index 00000000000..c6176e4940e
--- /dev/null
+++ b/config/components/rbac/mpijob_editor_role.yaml
@@ -0,0 +1,27 @@
+# permissions for end users to edit MPIJobs.
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRole
+metadata:
+  name: mpijob-editor-role
+  labels:
+    rbac.kueue.x-k8s.io/batch-admin: "true"
+    rbac.kueue.x-k8s.io/batch-user: "true"
+rules:
+- apiGroups:
+  - kubeflow.org
+  resources:
+  - mpijobs
+  verbs:
+  - create
+  - delete
+  - get
+  - list
+  - patch
+  - update
+  - watch
+- apiGroups:
+  - kubeflow.org
+  resources:
+  - mpijobs/status
+  verbs:
+  - get
diff --git a/config/components/rbac/mpijob_viewer_role.yaml b/config/components/rbac/mpijob_viewer_role.yaml
new file mode 100644
index 00000000000..54e8d874144
--- /dev/null
+++ b/config/components/rbac/mpijob_viewer_role.yaml
@@ -0,0 +1,23 @@
+# permissions for end users to view MPIJobs.
+apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: mpijob-viewer-role + labels: + rbac.kueue.x-k8s.io/batch-admin: "true" + rbac.kueue.x-k8s.io/batch-user: "true" +rules: +- apiGroups: + - kubeflow.org + resources: + - mpijobs + verbs: + - get + - list + - watch +- apiGroups: + - kubeflow.org + resources: + - mpijobs/status + verbs: + - get diff --git a/config/components/rbac/role.yaml b/config/components/rbac/role.yaml index c8d7bc8068a..dacb69a875a 100644 --- a/config/components/rbac/role.yaml +++ b/config/components/rbac/role.yaml @@ -72,6 +72,22 @@ rules: - jobs/status verbs: - get +- apiGroups: + - kubeflow.org + resources: + - mpijobs + verbs: + - get + - list + - patch + - update + - watch +- apiGroups: + - kubeflow.org + resources: + - mpijobs/status + verbs: + - get - apiGroups: - kueue.x-k8s.io resources: diff --git a/config/components/webhook/manifests.yaml b/config/components/webhook/manifests.yaml index 8a3074cf21c..2a2ab694ffe 100644 --- a/config/components/webhook/manifests.yaml +++ b/config/components/webhook/manifests.yaml @@ -82,6 +82,25 @@ webhooks: resources: - jobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /mutate-kubeflow-org-v2beta1-mpijob + failurePolicy: Fail + name: mmpijob.kb.io + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v2beta1 + operations: + - CREATE + resources: + - mpijobs + sideEffects: None --- apiVersion: admissionregistration.k8s.io/v1 kind: ValidatingWebhookConfiguration @@ -188,3 +207,22 @@ webhooks: resources: - jobs sideEffects: None +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v2beta1-mpijob + failurePolicy: Fail + name: vmpijob.kb.io + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v2beta1 + operations: + - UPDATE + resources: + - mpijobs + sideEffects: None diff --git a/go.mod b/go.mod index 641419f0573..1e173e5e17d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.19 require ( github.com/go-logr/logr v1.2.3 github.com/google/go-cmp v0.5.9 + github.com/kubeflow/common v0.4.6 + github.com/kubeflow/mpi-operator v0.3.1-0.20230210191002-c21942d1e27d github.com/onsi/ginkgo/v2 v2.8.3 github.com/onsi/gomega v1.27.1 github.com/open-policy-agent/cert-controller v0.7.0 @@ -26,20 +28,20 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/emicklei/go-restful/v3 v3.9.0 // indirect - github.com/evanphx/json-patch v5.6.0+incompatible // indirect + github.com/evanphx/json-patch v4.12.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.6.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-logr/zapr v1.2.3 // indirect - github.com/go-openapi/jsonpointer v0.19.5 // indirect - github.com/go-openapi/jsonreference v0.20.0 // indirect - github.com/go-openapi/swag v0.19.15 // indirect + github.com/go-openapi/jsonpointer v0.19.6 // indirect + github.com/go-openapi/jsonreference v0.20.1 // indirect + github.com/go-openapi/swag v0.22.3 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/gnostic v0.5.7-v3refs // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/pprof 
v0.0.0-20210720184732-4bb14d4b1be1 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/google/uuid v1.3.0 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/josharian/intern v1.0.0 // indirect @@ -54,8 +56,8 @@ require ( github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.7.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/oauth2 v0.0.0-20220309155454-6242fa91716a // indirect golang.org/x/sys v0.5.0 // indirect @@ -70,7 +72,7 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiextensions-apiserver v0.26.1 // indirect - k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 // indirect + k8s.io/kube-openapi v0.0.0-20230109183929-3758b55a6596 // indirect sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect sigs.k8s.io/yaml v1.3.0 // indirect diff --git a/go.sum b/go.sum index dc026d02330..be808f51661 100644 --- a/go.sum +++ b/go.sum @@ -67,8 +67,8 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= -github.com/evanphx/json-patch v5.6.0+incompatible h1:jBYDEEiFBPxA0v50tFdvOzQQTCvpL6mnFh5mB2/l16U= -github.com/evanphx/json-patch v5.6.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v4.12.0+incompatible h1:4onqiflcdA9EOZ4RxV643DvftH5pOlLGNtQ5lPWQu84= +github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.6.0 h1:b91NhWfaz02IuVxO9faSllyAtNXHMPkC5J8sJCLunww= github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= @@ -90,14 +90,12 @@ github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/zapr v1.2.3 h1:a9vnzlIBPQBBkeaR9IuMUfmVOrQlkoC4YfPoFkX3T7A= github.com/go-logr/zapr v1.2.3/go.mod h1:eIauM6P8qSvTw5o2ez6UEAfGjQKrxQTl5EoK+Qa2oG4= -github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= -github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonreference v0.20.0 h1:MYlu0sBgChmCfJxxUKZ8g1cPWFOB37YSZqewK7OKeyA= -github.com/go-openapi/jsonreference v0.20.0/go.mod h1:Ag74Ico3lPc+zR+qjn4XBUmXymS4zJbYVCZmcgkasdo= -github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= -github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= 
+github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= +github.com/go-openapi/jsonreference v0.20.1 h1:FBLnyygC4/IZZr893oiomc9XaghoveYTrLC1F86HID8= +github.com/go-openapi/jsonreference v0.20.1/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= +github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= +github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= @@ -161,8 +159,8 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= -github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -194,13 +192,16 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/kubeflow/common v0.4.6 h1:yzJf/HEdS6ginD0GlVkgbOFie0Sp66VdGjXidAGZIlk= +github.com/kubeflow/common v0.4.6/go.mod h1:43MAof/uhpJA2C0urynqatE3oKFQc7m2HLmJty7waqY= +github.com/kubeflow/mpi-operator v0.3.1-0.20230210191002-c21942d1e27d h1:KSqZfacyo92CwZHOf4/Ph6E9gOjOqc1q92Tp310lxlw= +github.com/kubeflow/mpi-operator v0.3.1-0.20230210191002-c21942d1e27d/go.mod h1:K2ijkebWk64hAaagoZ02w6QMXSkt+RmunvaXhkC/C0c= 
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -217,8 +218,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo/v2 v2.8.3 h1:RpbK1G8nWPNaCVFBWsOGnEQQGgASi6b8fxcWBvDYjxQ= github.com/onsi/ginkgo/v2 v2.8.3/go.mod h1:6OaUA8BCi0aZfmzYT/q9AacwTzDpNbxILUT+TlBq6MY= github.com/onsi/gomega v1.27.1 h1:rfztXRbg6nv/5f+Raen9RcGoSecHIFgBBLQK3Wdj754= @@ -267,13 +266,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -283,14 +286,12 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/atomic v1.9.0 
h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= -go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= @@ -578,8 +579,8 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= @@ -618,8 +619,8 @@ k8s.io/component-helpers v0.26.1/go.mod h1:jxNTnHb1axLe93MyVuvKj9T/+f4nxBVrj/xf0 k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4= k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= k8s.io/kube-aggregator v0.23.2 h1:6CoZZqNdFc9benrgSJJ0GQGgFtKjI0y3UwlBbioXtc8= -k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280 h1:+70TFaan3hfJzs+7VK2o+OGxg8HsuBr/5f6tVAjDu6E= -k8s.io/kube-openapi v0.0.0-20221012153701-172d655c2280/go.mod h1:+Axhij7bCpeqhklhUTe3xmOn6bWxolyZEeyaFpjGtl4= +k8s.io/kube-openapi v0.0.0-20230109183929-3758b55a6596 h1:8cNCQs+WqqnSpZ7y0LMQPKD+RZUHU17VqLPMW3qxnxc= +k8s.io/kube-openapi v0.0.0-20230109183929-3758b55a6596/go.mod h1:/BYxry62FuDzmI+i9B+X2pqfySRmSOW2ARmj5Zbqhj0= k8s.io/utils v0.0.0-20221128185143-99ec85e7a448 h1:KTgPnR10d5zhztWptI952TNtt/4u5h3IzDXkdIMuo2Y= k8s.io/utils v0.0.0-20221128185143-99ec85e7a448/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/main.go b/main.go index 84a86bfc817..4edd9f72b76 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,7 @@ import ( // to ensure that exec-entrypoint and run can make use of them. 
_ "k8s.io/client-go/plugin/pkg/client/auth" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" zaplog "go.uber.org/zap" "go.uber.org/zap/zapcore" schedulingv1 "k8s.io/api/scheduling/v1" @@ -46,6 +47,7 @@ import ( "sigs.k8s.io/kueue/pkg/controller/core" "sigs.k8s.io/kueue/pkg/controller/core/indexer" "sigs.k8s.io/kueue/pkg/controller/workload/job" + "sigs.k8s.io/kueue/pkg/controller/workload/mpijob" "sigs.k8s.io/kueue/pkg/metrics" "sigs.k8s.io/kueue/pkg/queue" "sigs.k8s.io/kueue/pkg/scheduler" @@ -66,6 +68,7 @@ func init() { utilruntime.Must(kueue.AddToScheme(scheme)) utilruntime.Must(config.AddToScheme(scheme)) + utilruntime.Must(kubeflow.AddToScheme(scheme)) // +kubebuilder:scaffold:scheme } @@ -148,6 +151,9 @@ func setupIndexes(ctx context.Context, mgr ctrl.Manager) { if err := job.SetupIndexes(ctx, mgr.GetFieldIndexer()); err != nil { setupLog.Error(err, "Unable to setup job indexes") } + if err := mpijob.SetupIndexes(ctx, mgr.GetFieldIndexer()); err != nil { + setupLog.Error(err, "Unable to setup mpijob indexes") + } } func setupControllers(mgr ctrl.Manager, cCache *cache.Cache, queues *queue.Manager, certsReady chan struct{}, cfg *config.Configuration) { @@ -171,6 +177,15 @@ func setupControllers(mgr ctrl.Manager, cCache *cache.Cache, queues *queue.Manag setupLog.Error(err, "unable to create controller", "controller", "Job") os.Exit(1) } + if err := mpijob.NewReconciler(mgr.GetScheme(), + mgr.GetClient(), + mgr.GetEventRecorderFor(constants.KueueName+"-mpijob-controller"), + mpijob.WithManageJobsWithoutQueueName(manageJobsWithoutQueueName), + mpijob.WithWaitForPodsReady(waitForPodsReady(cfg)), + ).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "MPIJob") + os.Exit(1) + } if failedWebhook, err := webhooks.Setup(mgr); err != nil { setupLog.Error(err, "Unable to create webhook", "webhook", failedWebhook) os.Exit(1) @@ -179,6 +194,10 @@ func setupControllers(mgr ctrl.Manager, cCache *cache.Cache, queues *queue.Manag setupLog.Error(err, "Unable to create webhook", "webhook", "Job") os.Exit(1) } + if err := mpijob.SetupMPIJobWebhook(mgr, mpijob.WithManageJobsWithoutQueueName(manageJobsWithoutQueueName)); err != nil { + setupLog.Error(err, "Unable to create webhook", "webhook", "MPIJob") + os.Exit(1) + } // +kubebuilder:scaffold:builder } diff --git a/pkg/controller/workload/job/job_controller.go b/pkg/controller/workload/job/job_controller.go index 7aea6ff0012..a2748617496 100644 --- a/pkg/controller/workload/job/job_controller.go +++ b/pkg/controller/workload/job/job_controller.go @@ -41,6 +41,7 @@ import ( kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2" "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/controller/workload/jobframework" utilpriority "sigs.k8s.io/kueue/pkg/util/priority" "sigs.k8s.io/kueue/pkg/workload" ) @@ -519,7 +520,7 @@ func ConstructWorkloadFor(ctx context.Context, client client.Client, job *batchv1.Job, scheme *runtime.Scheme) (*kueue.Workload, error) { w := &kueue.Workload{ ObjectMeta: metav1.ObjectMeta{ - Name: job.Name, + Name: GetWorkloadNameForJob(job.Name), Namespace: job.Namespace, }, Spec: kueue.WorkloadSpec{ @@ -630,3 +631,8 @@ func queueName(job *batchv1.Job) string { func parentWorkloadName(job *batchv1.Job) string { return job.Annotations[constants.ParentWorkloadAnnotation] } + +func GetWorkloadNameForJob(jobName string) string { + gvk := metav1.GroupVersionKind{Group: batchv1.SchemeGroupVersion.Group, Version: batchv1.SchemeGroupVersion.Version, Kind: "Job"} + return 
jobframework.GetWorkloadNameForOwnerWithGVK(jobName, &gvk) +} diff --git a/pkg/controller/workload/job/job_webhook.go b/pkg/controller/workload/job/job_webhook.go index 024819ce834..94502e491c7 100644 --- a/pkg/controller/workload/job/job_webhook.go +++ b/pkg/controller/workload/job/job_webhook.go @@ -22,6 +22,7 @@ import ( batchv1 "k8s.io/api/batch/v1" apivalidation "k8s.io/apimachinery/pkg/api/validation" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation" "k8s.io/apimachinery/pkg/util/validation/field" @@ -30,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/controller/workload/jobframework" "sigs.k8s.io/kueue/pkg/util/pointer" ) @@ -69,6 +71,17 @@ func (w *JobWebhook) Default(ctx context.Context, obj runtime.Object) error { log := ctrl.LoggerFrom(ctx).WithName("job-webhook") log.V(5).Info("Applying defaults", "job", klog.KObj(job)) + if owner := metav1.GetControllerOf(job); owner != nil && jobframework.KnownWorkloadOwner(owner) { + if job.Annotations == nil { + job.Annotations = make(map[string]string) + } + if pwName, err := jobframework.GetWorkloadNameForOwnerRef(owner); err != nil { + return err + } else { + job.Annotations[constants.ParentWorkloadAnnotation] = pwName + } + } + if queueName(job) == "" && !w.manageJobsWithoutQueueName { return nil } diff --git a/pkg/controller/workload/jobframework/known_frameworks.go b/pkg/controller/workload/jobframework/known_frameworks.go new file mode 100644 index 00000000000..ec57df4126f --- /dev/null +++ b/pkg/controller/workload/jobframework/known_frameworks.go @@ -0,0 +1,31 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jobframework + +import ( + "strings" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func KnownWorkloadOwner(owner *metav1.OwnerReference) bool { + return IsMPIJob(owner) +} + +func IsMPIJob(owner *metav1.OwnerReference) bool { + return owner.Kind == "MPIJob" && strings.HasPrefix(owner.APIVersion, "kubeflow.org/v2") +} diff --git a/pkg/controller/workload/jobframework/workload_names.go b/pkg/controller/workload/jobframework/workload_names.go new file mode 100644 index 00000000000..3750a91f27f --- /dev/null +++ b/pkg/controller/workload/jobframework/workload_names.go @@ -0,0 +1,59 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package jobframework + +import ( + "crypto/sha1" + "encoding/hex" + "strings" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const ( + hashLength = 5 + // 253 is the maximal length for a CRD name. We need to subtract one for '-', and the hash length. + maxPrefixLength = 252 - hashLength +) + +func GetWorkloadNameForOwnerRef(owner *metav1.OwnerReference) (string, error) { + gv, err := schema.ParseGroupVersion(owner.APIVersion) + if err != nil { + return "", err + } + gvk := metav1.GroupVersionKind{Group: gv.Group, Version: gv.Version, Kind: owner.Kind} + return GetWorkloadNameForOwnerWithGVK(owner.Name, &gvk), nil +} + +func GetWorkloadNameForOwnerWithGVK(ownerName string, ownerGVK *metav1.GroupVersionKind) string { + prefixedName := strings.ToLower(ownerGVK.Kind) + "-" + ownerName + if len(prefixedName) > maxPrefixLength { + prefixedName = prefixedName[:maxPrefixLength] + } + return prefixedName + "-" + getHash(ownerName, ownerGVK)[:hashLength] +} + +func getHash(ownerName string, gvk *metav1.GroupVersionKind) string { + h := sha1.New() + h.Write([]byte(gvk.Kind)) + h.Write([]byte("\n")) + h.Write([]byte(gvk.Group)) + h.Write([]byte("\n")) + h.Write([]byte(ownerName)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/pkg/controller/workload/jobframework/workload_names_test.go b/pkg/controller/workload/jobframework/workload_names_test.go new file mode 100644 index 00000000000..0c34e0b921b --- /dev/null +++ b/pkg/controller/workload/jobframework/workload_names_test.go @@ -0,0 +1,70 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package jobframework + +import ( + "errors" + "strings" + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestGetWorkloadNameForOwner(t *testing.T) { + testCases := map[string]struct { + owner *metav1.OwnerReference + wantWorkloadName string + wantErr error + }{ + "simple name": { + owner: &metav1.OwnerReference{Kind: "Job", APIVersion: "batch/v1", Name: "myjob"}, + wantWorkloadName: "job-myjob-2f79a", + }, + "lengthy name": { + owner: &metav1.OwnerReference{Kind: "Job", APIVersion: "batch/v1", Name: strings.Repeat("a", 252)}, + wantWorkloadName: "job-" + strings.Repeat("a", 253-4-6) + "-f1c14", + }, + "simple MPIJob name for v2beta1": { + owner: &metav1.OwnerReference{Kind: "MPIJob", APIVersion: "kubeflow.org/v2beta1", Name: "myjob"}, + wantWorkloadName: "mpijob-myjob-98672", + }, + "simple MPIJob name for v2; should be as for v2beta1": { + owner: &metav1.OwnerReference{Kind: "MPIJob", APIVersion: "kubeflow.org/v2", Name: "myjob"}, + wantWorkloadName: "mpijob-myjob-98672", + }, + "invalid APIVersion": { + owner: &metav1.OwnerReference{Kind: "Job", APIVersion: "batch/v1/beta1", Name: "myjob"}, + wantErr: errors.New("unexpected GroupVersion string: batch/v1/beta1"), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + gotWorkloadName, gotErr := GetWorkloadNameForOwnerRef(tc.owner) + if tc.wantWorkloadName != gotWorkloadName { + t.Errorf("Unexpected workload name. want=%s, got=%s", tc.wantWorkloadName, gotWorkloadName) + } + if gotErr != nil && tc.wantErr == nil { + t.Errorf("Unexpected error present: %s", gotErr.Error()) + } else if gotErr == nil && tc.wantErr != nil { + t.Errorf("Missing expected error: %s", tc.wantErr.Error()) + } else if gotErr != nil && tc.wantErr != nil && gotErr.Error() != tc.wantErr.Error() { + t.Errorf("Unexpected error. want: %s, got: %s.", tc.wantErr.Error(), gotErr.Error()) + } + }) + } +} diff --git a/pkg/controller/workload/mpijob/mpijob_controller.go b/pkg/controller/workload/mpijob/mpijob_controller.go new file mode 100644 index 00000000000..2979d0f5f62 --- /dev/null +++ b/pkg/controller/workload/mpijob/mpijob_controller.go @@ -0,0 +1,588 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package mpijob + +import ( + "context" + "fmt" + "strings" + + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + apimeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/tools/record" + "k8s.io/klog/v2" + "k8s.io/utils/pointer" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2" + "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/controller/workload/jobframework" + utilpriority "sigs.k8s.io/kueue/pkg/util/priority" + "sigs.k8s.io/kueue/pkg/workload" +) + +var ( + ownerKey = ".metadata.mpijob_controller" +) + +// MPIJobReconciler reconciles a Job object +type MPIJobReconciler struct { + client client.Client + scheme *runtime.Scheme + record record.EventRecorder + manageJobsWithoutQueueName bool + waitForPodsReady bool +} + +type options struct { + manageJobsWithoutQueueName bool + waitForPodsReady bool +} + +// Option configures the reconciler. +type Option func(*options) + +// WithManageJobsWithoutQueueName indicates if the controller should reconcile +// jobs that don't set the queue name annotation. +func WithManageJobsWithoutQueueName(f bool) Option { + return func(o *options) { + o.manageJobsWithoutQueueName = f + } +} + +// WithWaitForPodsReady indicates if the controller should add the PodsReady +// condition to the workload when the corresponding job has all pods ready +// or succeeded. +func WithWaitForPodsReady(f bool) Option { + return func(o *options) { + o.waitForPodsReady = f + } +} + +var defaultOptions = options{} + +func NewReconciler( + scheme *runtime.Scheme, + client client.Client, + record record.EventRecorder, + opts ...Option) *MPIJobReconciler { + options := defaultOptions + for _, opt := range opts { + opt(&options) + } + + return &MPIJobReconciler{ + scheme: scheme, + client: client, + record: record, + manageJobsWithoutQueueName: options.manageJobsWithoutQueueName, + waitForPodsReady: options.waitForPodsReady, + } +} + +// SetupWithManager sets up the controller with the Manager. It indexes workloads +// based on the owning jobs. +func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&kubeflow.MPIJob{}). + Owns(&kueue.Workload{}). + Complete(r) +} + +func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error { + return indexer.IndexField(ctx, &kueue.Workload{}, ownerKey, func(o client.Object) []string { + // grab the Workload object, extract the owner... + wl := o.(*kueue.Workload) + owner := metav1.GetControllerOf(wl) + if owner == nil { + return nil + } + // ...make sure it's an MPIJob... 
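+	// i.e. the owner Kind is "MPIJob" and its APIVersion starts with "kubeflow.org/v2".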
+ if !jobframework.IsMPIJob(owner) { + return nil + } + // ...and if so, return it + return []string{owner.Name} + }) +} + +//+kubebuilder:rbac:groups=scheduling.k8s.io,resources=priorityclasses,verbs=list;get;watch +//+kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update +//+kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs,verbs=get;list;watch;update;patch +//+kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs/status,verbs=get +//+kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update +//+kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch + +func (r *MPIJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + var job kubeflow.MPIJob + if err := r.client.Get(ctx, req.NamespacedName, &job); err != nil { + // we'll ignore not-found errors, since there is nothing to do. + return ctrl.Result{}, client.IgnoreNotFound(err) + } + + log := ctrl.LoggerFrom(ctx).WithValues("mpijob", klog.KObj(&job)) + ctx = ctrl.LoggerInto(ctx, log) + + // when manageJobsWithoutQueueName is disabled we only reconcile jobs that have + // queue-name annotation set. + if !r.manageJobsWithoutQueueName && queueName(&job) == "" { + log.V(3).Info(fmt.Sprintf("%s annotation is not set, ignoring the mpijob", constants.QueueAnnotation)) + return ctrl.Result{}, nil + } + + log.V(2).Info("Reconciling MPIJob") + + // 1. make sure there is only a single existing instance of the workload + wl, err := r.ensureAtMostOneWorkload(ctx, &job) + if err != nil { + log.Error(err, "Getting existing workloads") + return ctrl.Result{}, err + } + + // 2. handle mpijob is finished. + if jobFinishedCond, jobFinished := jobFinishedCondition(&job); jobFinished { + if wl == nil || apimeta.IsStatusConditionTrue(wl.Status.Conditions, kueue.WorkloadFinished) { + return ctrl.Result{}, nil + } + condition := generateFinishedCondition(jobFinishedCond) + apimeta.SetStatusCondition(&wl.Status.Conditions, condition) + err := r.client.Status().Update(ctx, wl) + if err != nil { + log.Error(err, "Updating workload status") + } + return ctrl.Result{}, err + } + + // 3. handle workload is nil. + if wl == nil { + err := r.handleJobWithNoWorkload(ctx, &job) + if err != nil { + log.Error(err, "Handling mpijob with no workload") + } + return ctrl.Result{}, err + } + + // 4. handle WaitForPodsReady + if r.waitForPodsReady { + log.V(5).Info("Handling a mpijob when waitForPodsReady is enabled") + condition := generatePodsReadyCondition(&job, wl) + // optimization to avoid sending the update request if the status didn't change + if !apimeta.IsStatusConditionPresentAndEqual(wl.Status.Conditions, condition.Type, condition.Status) { + log.V(3).Info(fmt.Sprintf("Updating the PodsReady condition with status: %v", condition.Status)) + apimeta.SetStatusCondition(&wl.Status.Conditions, condition) + if err := r.client.Status().Update(ctx, wl); err != nil { + log.Error(err, "Updating workload status") + } + } + } + + // 5. handle mpijob is suspended. 
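+	// The MPIJob remains suspended until its workload is admitted; once the
+	// workload carries an admission, the job is started (unsuspended). While
+	// waiting, only the workload's queue name is kept in sync with the job's annotation.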
+ if jobSuspended(&job) { + // start the job if the workload has been admitted, and the job is still suspended + if wl.Spec.Admission != nil { + log.V(2).Info("Job admitted, unsuspending") + err := r.startJob(ctx, wl, &job) + if err != nil { + log.Error(err, "Unsuspending job") + } + return ctrl.Result{}, err + } + + // update queue name if changed. + q := queueName(&job) + if wl.Spec.QueueName != q { + log.V(2).Info("Job changed queues, updating workload") + wl.Spec.QueueName = q + err := r.client.Update(ctx, wl) + if err != nil { + log.Error(err, "Updating workload queue") + } + return ctrl.Result{}, err + } + log.V(3).Info("Job is suspended and workload not yet admitted by a clusterQueue, nothing to do") + return ctrl.Result{}, nil + } + + // 6. handle job is unsuspended. + if wl.Spec.Admission == nil { + // the job must be suspended if the workload is not yet admitted. + log.V(2).Info("Running job is not admitted by a cluster queue, suspending") + err := r.stopJob(ctx, wl, &job, "Not admitted by cluster queue") + if err != nil { + log.Error(err, "Suspending job with non admitted workload") + } + return ctrl.Result{}, err + } + + // workload is admitted and job is running, nothing to do. + log.V(3).Info("Job running with admitted workload, nothing to do") + return ctrl.Result{}, nil +} + +// podsReady checks if all pods are ready or succeeded +func podsReady(job *kubeflow.MPIJob) bool { + for _, c := range job.Status.Conditions { + if c.Type == kubeflow.JobRunning && c.Status == corev1.ConditionTrue { + return true + } + } + return false +} + +// stopJob sends updates to suspend the job, reset the startTime so we can update the scheduling directives +// later when unsuspending and resets the nodeSelector to its previous state based on what is available in +// the workload (which should include the original affinities that the job had). +func (r *MPIJobReconciler) stopJob(ctx context.Context, w *kueue.Workload, + job *kubeflow.MPIJob, eventMsg string) error { + job.Spec.RunPolicy.Suspend = pointer.Bool(true) + if err := r.client.Update(ctx, job); err != nil { + return err + } + r.record.Eventf(job, corev1.EventTypeNormal, "Stopped", eventMsg) + + // Reset start time so we can update the scheduling directives later when unsuspending. 
+ if job.Status.StartTime != nil { + job.Status.StartTime = nil + if err := r.client.Status().Update(ctx, job); err != nil { + return err + } + } + + if w != nil { + orderedReplicaTypes := orderedReplicaTypes(&job.Spec) + for index := range w.Spec.PodSets { + replicaType := orderedReplicaTypes[index] + if !equality.Semantic.DeepEqual(job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector, + w.Spec.PodSets[index].Spec.NodeSelector) { + job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = map[string]string{} + for k, v := range w.Spec.PodSets[index].Spec.NodeSelector { + job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v + } + } + } + return r.client.Update(ctx, job) + } + + return nil +} + +func (r *MPIJobReconciler) startJob(ctx context.Context, w *kueue.Workload, job *kubeflow.MPIJob) error { + log := ctrl.LoggerFrom(ctx) + + orderedReplicaTypes := orderedReplicaTypes(&job.Spec) + for index := range w.Spec.PodSets { + replicaType := orderedReplicaTypes[index] + nodeSelector, err := r.getNodeSelectors(ctx, w, index) + if err != nil { + return err + } + if len(nodeSelector) != 0 { + if job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector == nil { + job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector = nodeSelector + } else { + for k, v := range nodeSelector { + job.Spec.MPIReplicaSpecs[replicaType].Template.Spec.NodeSelector[k] = v + } + } + } else { + log.V(3).Info("no nodeSelectors to inject") + } + } + + job.Spec.RunPolicy.Suspend = pointer.Bool(false) + if err := r.client.Update(ctx, job); err != nil { + return err + } + + r.record.Eventf(job, corev1.EventTypeNormal, "Started", + "Admitted by clusterQueue %v", w.Spec.Admission.ClusterQueue) + return nil +} + +func (r *MPIJobReconciler) getNodeSelectors(ctx context.Context, w *kueue.Workload, index int) (map[string]string, error) { + if len(w.Spec.Admission.PodSetFlavors[index].Flavors) == 0 { + return nil, nil + } + + processedFlvs := sets.NewString() + nodeSelector := map[string]string{} + for _, flvName := range w.Spec.Admission.PodSetFlavors[index].Flavors { + if processedFlvs.Has(flvName) { + continue + } + // Lookup the ResourceFlavors to fetch the node affinity labels to apply on the job. + flv := kueue.ResourceFlavor{} + if err := r.client.Get(ctx, types.NamespacedName{Name: flvName}, &flv); err != nil { + return nil, err + } + for k, v := range flv.NodeSelector { + nodeSelector[k] = v + } + processedFlvs.Insert(flvName) + } + return nodeSelector, nil +} + +func (r *MPIJobReconciler) handleJobWithNoWorkload(ctx context.Context, job *kubeflow.MPIJob) error { + log := ctrl.LoggerFrom(ctx) + + // Wait until there are no active pods. + for _, replicaStatus := range job.Status.ReplicaStatuses { + if replicaStatus.Active != 0 { + log.V(2).Info("Job is suspended but still has active pods, waiting") + return nil + } + } + + // Create the corresponding workload. + wl, err := ConstructWorkloadFor(ctx, r.client, job, r.scheme) + if err != nil { + return err + } + if err = r.client.Create(ctx, wl); err != nil { + return err + } + + r.record.Eventf(job, corev1.EventTypeNormal, "CreatedWorkload", + "Created Workload: %v", workload.Key(wl)) + return nil +} + +// ensureAtMostOneWorkload finds a matching workload and deletes redundant ones. +func (r *MPIJobReconciler) ensureAtMostOneWorkload(ctx context.Context, job *kubeflow.MPIJob) (*kueue.Workload, error) { + log := ctrl.LoggerFrom(ctx) + + // Find a matching workload first if there is one. 
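+	// Owned workloads that no longer match the MPIJob spec are collected in
+	// toDelete and removed below; at most one matching workload is kept.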
+ var toDelete []*kueue.Workload + var match *kueue.Workload + + var workloads kueue.WorkloadList + if err := r.client.List(ctx, &workloads, client.InNamespace(job.Namespace), + client.MatchingFields{ownerKey: job.Name}); err != nil { + log.Error(err, "Unable to list child workloads") + return nil, err + } + + for i := range workloads.Items { + w := &workloads.Items[i] + owner := metav1.GetControllerOf(w) + // Indexes don't work in unit tests, so we explicitly check for the + // owner here. + if owner.Name != job.Name { + continue + } + if match == nil && jobAndWorkloadEqual(job, w) { + match = w + } else { + toDelete = append(toDelete, w) + } + } + + // If there is no matching workload and the job is running, suspend it. + if match == nil && !jobSuspended(job) { + log.V(2).Info("job with no matching workload, suspending") + var w *kueue.Workload + if len(workloads.Items) == 1 { + // The job may have been modified and hence the existing workload + // doesn't match the job anymore. All bets are off if there are more + // than one workload... + w = &workloads.Items[0] + } + if err := r.stopJob(ctx, w, job, "No matching Workload"); err != nil { + log.Error(err, "stopping job") + } + } + + // Delete duplicate workload instances. + existedWls := 0 + for i := range toDelete { + err := r.client.Delete(ctx, toDelete[i]) + if err == nil || !apierrors.IsNotFound(err) { + existedWls++ + } + if err != nil && !apierrors.IsNotFound(err) { + log.Error(err, "Failed to delete workload") + } + if err == nil { + r.record.Eventf(job, corev1.EventTypeNormal, "DeletedWorkload", + "Deleted not matching Workload: %v", workload.Key(toDelete[i])) + } + } + + if existedWls != 0 { + if match == nil { + return nil, fmt.Errorf("no matching workload was found, tried deleting %d existing workload(s)", existedWls) + } + return nil, fmt.Errorf("only one workload should exist, found %d", len(workloads.Items)) + } + + return match, nil +} + +func ConstructWorkloadFor(ctx context.Context, client client.Client, + job *kubeflow.MPIJob, scheme *runtime.Scheme) (*kueue.Workload, error) { + w := &kueue.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: GetWorkloadNameForMPIJob(job.Name), + Namespace: job.Namespace, + }, + Spec: kueue.WorkloadSpec{ + QueueName: queueName(job), + }, + } + + for _, mpiReplicaType := range orderedReplicaTypes(&job.Spec) { + podSet := kueue.PodSet{ + Name: strings.ToLower(string(mpiReplicaType)), + Spec: *job.Spec.MPIReplicaSpecs[mpiReplicaType].Template.Spec.DeepCopy(), + Count: podsCount(&job.Spec, mpiReplicaType), + } + w.Spec.PodSets = append(w.Spec.PodSets, podSet) + } + + // Populate priority from priority class. + selectedPriorityClassName := calcPriorityClassName(job) + priorityClassName, p, err := utilpriority.GetPriorityFromPriorityClass( + ctx, client, selectedPriorityClassName) + if err != nil { + return nil, err + } + w.Spec.Priority = &p + w.Spec.PriorityClassName = priorityClassName + + if err := ctrl.SetControllerReference(job, w, scheme); err != nil { + return nil, err + } + + return w, nil +} + +// calcPriorityClassName calculates the priorityClass name needed for workload according to the following priorities: +// 1. .spec.runPolicy.schedulingPolicy.priorityClass +// 2. .spec.mpiReplicaSecs[Launcher].template.spec.priorityClassName +// 3. 
.spec.mpiReplicaSecs[Worker].template.spec.priorityClassName +func calcPriorityClassName(job *kubeflow.MPIJob) string { + if job.Spec.RunPolicy.SchedulingPolicy != nil && len(job.Spec.RunPolicy.SchedulingPolicy.PriorityClass) != 0 { + return job.Spec.RunPolicy.SchedulingPolicy.PriorityClass + } else if l := job.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; l != nil && len(l.Template.Spec.PriorityClassName) != 0 { + return l.Template.Spec.PriorityClassName + } else if w := job.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; w != nil && len(w.Template.Spec.PriorityClassName) != 0 { + return w.Template.Spec.PriorityClassName + } + return "" +} + +func orderedReplicaTypes(jobSpec *kubeflow.MPIJobSpec) []kubeflow.MPIReplicaType { + var result []kubeflow.MPIReplicaType + if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; ok { + result = append(result, kubeflow.MPIReplicaTypeLauncher) + } + if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; ok { + result = append(result, kubeflow.MPIReplicaTypeWorker) + } + return result +} + +func podsCount(jobSpec *kubeflow.MPIJobSpec, mpiReplicaType kubeflow.MPIReplicaType) int32 { + return pointer.Int32Deref(jobSpec.MPIReplicaSpecs[mpiReplicaType].Replicas, 1) +} + +func generatePodsReadyCondition(job *kubeflow.MPIJob, wl *kueue.Workload) metav1.Condition { + conditionStatus := metav1.ConditionFalse + message := "Not all pods are ready or succeeded" + // Once PodsReady=True it stays as long as the workload remains admitted to + // avoid unnecessary flickering the the condition. + if wl.Spec.Admission != nil && (podsReady(job) || apimeta.IsStatusConditionTrue(wl.Status.Conditions, kueue.WorkloadPodsReady)) { + conditionStatus = metav1.ConditionTrue + message = "All pods were ready or succeeded since the workload admission" + } + return metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: conditionStatus, + Reason: "PodsReady", + Message: message, + } +} + +func generateFinishedCondition(jobStatus kubeflow.JobConditionType) metav1.Condition { + message := "Job finished successfully" + if jobStatus == kubeflow.JobFailed { + message = "Job failed" + } + return metav1.Condition{ + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + Reason: "JobFinished", + Message: message, + } +} + +// From https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/job/utils.go +func jobFinishedCondition(j *kubeflow.MPIJob) (kubeflow.JobConditionType, bool) { + for _, c := range j.Status.Conditions { + if (c.Type == kubeflow.JobSucceeded || c.Type == kubeflow.JobFailed) && c.Status == corev1.ConditionTrue { + return c.Type, true + } + } + return "", false +} + +func jobSuspended(j *kubeflow.MPIJob) bool { + return j.Spec.RunPolicy.Suspend != nil && *j.Spec.RunPolicy.Suspend +} + +func jobAndWorkloadEqual(job *kubeflow.MPIJob, wl *kueue.Workload) bool { + if len(wl.Spec.PodSets) != len(job.Spec.MPIReplicaSpecs) { + return false + } + for index, mpiReplicaType := range orderedReplicaTypes(&job.Spec) { + mpiReplicaSpec := job.Spec.MPIReplicaSpecs[mpiReplicaType] + if pointer.Int32Deref(mpiReplicaSpec.Replicas, 1) != wl.Spec.PodSets[index].Count { + return false + } + // nodeSelector may change, hence we are not checking for + // equality of the whole job.Spec.Template.Spec. 
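+		// Only the InitContainers and Containers of each replica template are compared.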
+ if !equality.Semantic.DeepEqual(mpiReplicaSpec.Template.Spec.InitContainers, + wl.Spec.PodSets[index].Spec.InitContainers) { + return false + } + if !equality.Semantic.DeepEqual(mpiReplicaSpec.Template.Spec.Containers, + wl.Spec.PodSets[index].Spec.Containers) { + return false + } + } + return true +} + +func queueName(job *kubeflow.MPIJob) string { + return job.Annotations[constants.QueueAnnotation] +} + +func GetWorkloadNameForMPIJob(jobName string) string { + gvk := metav1.GroupVersionKind{Group: kubeflow.SchemeGroupVersion.Group, Version: kubeflow.SchemeGroupVersion.Version, Kind: kubeflow.SchemeGroupVersionKind.Kind} + return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, &gvk) +} diff --git a/pkg/controller/workload/mpijob/mpijob_controller_test.go b/pkg/controller/workload/mpijob/mpijob_controller_test.go new file mode 100644 index 00000000000..8a7468a6b72 --- /dev/null +++ b/pkg/controller/workload/mpijob/mpijob_controller_test.go @@ -0,0 +1,169 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpijob + +import ( + "testing" + + common "github.com/kubeflow/common/pkg/apis/common/v1" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + v1 "k8s.io/api/core/v1" +) + +func TestCalcPriorityClassName(t *testing.T) { + testcases := map[string]struct { + job kubeflow.MPIJob + wantPriorityClassName string + }{ + "none priority class name specified": { + job: kubeflow.MPIJob{}, + wantPriorityClassName: "", + }, + "priority specified at runPolicy and replicas; use priority in runPolicy": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + RunPolicy: kubeflow.RunPolicy{ + SchedulingPolicy: &kubeflow.SchedulingPolicy{ + PriorityClass: "scheduling-priority", + }, + }, + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "launcher-priority", + }, + }, + }, + kubeflow.MPIReplicaTypeWorker: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "worker-priority", + }, + }, + }, + }, + }, + }, + wantPriorityClassName: "scheduling-priority", + }, + "runPolicy present, but without priority; fallback to launcher": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + RunPolicy: kubeflow.RunPolicy{ + SchedulingPolicy: &kubeflow.SchedulingPolicy{}, + }, + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "launcher-priority", + }, + }, + }, + }, + }, + }, + wantPriorityClassName: "launcher-priority", + }, + "specified on launcher takes precedence over worker": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "launcher-priority", + }, + }, + }, + kubeflow.MPIReplicaTypeWorker: { + 
Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "worker-priority", + }, + }, + }, + }, + }, + }, + wantPriorityClassName: "launcher-priority", + }, + "launcher present, but without priority; fallback to worker": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{}, + }, + }, + kubeflow.MPIReplicaTypeWorker: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "worker-priority", + }, + }, + }, + }, + }, + }, + wantPriorityClassName: "worker-priority", + }, + "specified on worker only": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: {}, + kubeflow.MPIReplicaTypeWorker: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: "worker-priority", + }, + }, + }, + }, + }, + }, + wantPriorityClassName: "worker-priority", + }, + "worker present, but without priority; fallback to empty": { + job: kubeflow.MPIJob{ + Spec: kubeflow.MPIJobSpec{ + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: {}, + kubeflow.MPIReplicaTypeWorker: { + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{}, + }, + }, + }, + }, + }, + wantPriorityClassName: "", + }, + } + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + gotPriorityClassName := calcPriorityClassName(&tc.job) + if tc.wantPriorityClassName != gotPriorityClassName { + t.Errorf("Unexpected response (want: %v, got: %v)", tc.wantPriorityClassName, gotPriorityClassName) + } + }) + } +} diff --git a/pkg/controller/workload/mpijob/mpijob_webhook.go b/pkg/controller/workload/mpijob/mpijob_webhook.go new file mode 100644 index 00000000000..c9bf1876958 --- /dev/null +++ b/pkg/controller/workload/mpijob/mpijob_webhook.go @@ -0,0 +1,125 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpijob + +import ( + "context" + "strings" + + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + + "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/util/pointer" +) + +type MPIJobWebhook struct { + manageJobsWithoutQueueName bool +} + +// SetupWebhook configures the webhook for kubeflow MPIJob. +func SetupMPIJobWebhook(mgr ctrl.Manager, opts ...Option) error { + options := defaultOptions + for _, opt := range opts { + opt(&options) + } + wh := &MPIJobWebhook{ + manageJobsWithoutQueueName: options.manageJobsWithoutQueueName, + } + return ctrl.NewWebhookManagedBy(mgr). + For(&kubeflow.MPIJob{}). + WithDefaulter(wh). + WithValidator(wh). 
+ Complete() +} + +// +kubebuilder:webhook:path=/mutate-kubeflow-org-v2beta1-mpijob,mutating=true,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=create,versions=v2beta1,name=mmpijob.kb.io,admissionReviewVersions=v1 + +var _ webhook.CustomDefaulter = &MPIJobWebhook{} + +var ( + annotationsPath = field.NewPath("metadata", "annotations") + suspendPath = field.NewPath("spec", "runPolicy", "suspend") +) + +// Default implements webhook.CustomDefaulter so a webhook will be registered for the type +func (w *MPIJobWebhook) Default(ctx context.Context, obj runtime.Object) error { + job := obj.(*kubeflow.MPIJob) + log := ctrl.LoggerFrom(ctx).WithName("job-webhook") + log.V(5).Info("Applying defaults", "job", klog.KObj(job)) + + if queueName(job) == "" && !w.manageJobsWithoutQueueName { + return nil + } + + if job.Spec.RunPolicy.Suspend == nil || !*job.Spec.RunPolicy.Suspend { // guard against a nil Suspend and force kueue-managed MPIJobs to start suspended + job.Spec.RunPolicy.Suspend = pointer.Bool(true) + } + return nil +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v2beta1-mpijob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=update,versions=v2beta1,name=vmpijob.kb.io,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &MPIJobWebhook{} + +// ValidateCreate implements webhook.CustomValidator so a webhook will be registered for the type +func (w *MPIJobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) error { + job := obj.(*kubeflow.MPIJob) + return validateCreate(job).ToAggregate() +} + +func validateCreate(job *kubeflow.MPIJob) field.ErrorList { + klog.InfoS("validateCreate invoked", "mpijob", job.Name) + var allErrs field.ErrorList + for _, crdNameAnnotation := range []string{constants.ParentWorkloadAnnotation, constants.QueueAnnotation} { + if value, exists := job.Annotations[crdNameAnnotation]; exists { + if errs := validation.IsDNS1123Subdomain(value); len(errs) > 0 { + allErrs = append(allErrs, field.Invalid(annotationsPath.Key(crdNameAnnotation), value, strings.Join(errs, ","))) + } + } + } + return allErrs +} + +// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type +func (w *MPIJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error { + oldJob := oldObj.(*kubeflow.MPIJob) + newJob := newObj.(*kubeflow.MPIJob) + log := ctrl.LoggerFrom(ctx).WithName("job-webhook") + log.Info("Validating update", "mpijob", klog.KObj(newJob)) + return validateUpdate(oldJob, newJob).ToAggregate() +} + +func validateUpdate(oldJob, newJob *kubeflow.MPIJob) field.ErrorList { + allErrs := validateCreate(newJob) + + if newJob.Spec.RunPolicy.Suspend != nil && !*newJob.Spec.RunPolicy.Suspend && (queueName(oldJob) != queueName(newJob)) { + allErrs = append(allErrs, field.Forbidden(suspendPath, "must not update queue name when the job is unsuspended")) + } + + return allErrs +} + +// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type +func (w *MPIJobWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) error { + return nil +} diff --git a/pkg/util/testing/wrappers_mpijob.go b/pkg/util/testing/wrappers_mpijob.go new file mode 100644 index 00000000000..d5fb3d4d3d7 --- /dev/null +++ b/pkg/util/testing/wrappers_mpijob.go @@ -0,0 +1,115 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + common "github.com/kubeflow/common/pkg/apis/common/v1" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/util/pointer" +) + +// MPIJobWrapper wraps an MPIJob. +type MPIJobWrapper struct{ kubeflow.MPIJob } + +// MakeMPIJob creates a wrapper for a suspended MPIJob with a single-container launcher and worker, one replica each. +func MakeMPIJob(name, ns string) *MPIJobWrapper { + return &MPIJobWrapper{kubeflow.MPIJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns, + Annotations: make(map[string]string, 1), + }, + Spec: kubeflow.MPIJobSpec{ + RunPolicy: kubeflow.RunPolicy{ + Suspend: pointer.Bool(true), + }, + MPIReplicaSpecs: map[kubeflow.MPIReplicaType]*common.ReplicaSpec{ + kubeflow.MPIReplicaTypeLauncher: { + Replicas: pointer.Int32(1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + RestartPolicy: "Never", + Containers: []corev1.Container{ + { + Name: "c", + Image: "pause", + Command: []string{}, + Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}}, + }, + }, + NodeSelector: map[string]string{}, + }, + }, + }, + kubeflow.MPIReplicaTypeWorker: { + Replicas: pointer.Int32(1), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + RestartPolicy: "Never", + Containers: []corev1.Container{ + { + Name: "c", + Image: "pause", + Command: []string{}, + Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}}, + }, + }, + NodeSelector: map[string]string{}, + }, + }, + }, + }, + }, + }} +} + +// PriorityClass sets the priority class in the job's scheduling policy. +func (j *MPIJobWrapper) PriorityClass(pc string) *MPIJobWrapper { + if j.Spec.RunPolicy.SchedulingPolicy == nil { + j.Spec.RunPolicy.SchedulingPolicy = &kubeflow.SchedulingPolicy{} + } + j.Spec.RunPolicy.SchedulingPolicy.PriorityClass = pc + return j +} + +// Obj returns the inner MPIJob. +func (j *MPIJobWrapper) Obj() *kubeflow.MPIJob { + return &j.MPIJob +} + +// Queue updates the queue name of the MPIJob. +func (j *MPIJobWrapper) Queue(queue string) *MPIJobWrapper { + j.Annotations[constants.QueueAnnotation] = queue + return j +} + +// Request adds a resource request to the first container of the given replica type. +func (j *MPIJobWrapper) Request(replicaType kubeflow.MPIReplicaType, r corev1.ResourceName, v string) *MPIJobWrapper { + j.Spec.MPIReplicaSpecs[replicaType].Template.Spec.Containers[0].Resources.Requests[r] = resource.MustParse(v) + return j +} + +// Parallelism sets the number of worker replicas.
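+// Only the worker replica count is changed; the launcher keeps a single replica.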
+func (j *MPIJobWrapper) Parallelism(p int32) *MPIJobWrapper { + j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Replicas = pointer.Int32(p) + return j +} diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 320b82934ec..b1236454203 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -26,6 +26,7 @@ import ( "k8s.io/apimachinery/pkg/types" kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2" + workloadjob "sigs.k8s.io/kueue/pkg/controller/workload/job" "sigs.k8s.io/kueue/pkg/util/testing" "sigs.k8s.io/kueue/test/util" ) @@ -35,6 +36,7 @@ import ( var _ = ginkgo.Describe("Kueue", func() { var ns *corev1.Namespace var sampleJob *batchv1.Job + ginkgo.BeforeEach(func() { ns = &corev1.Namespace{ ObjectMeta: metav1.ObjectMeta{ @@ -60,9 +62,10 @@ var _ = ginkgo.Describe("Kueue", func() { } return *createdJob.Spec.Suspend }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + wlLookupKey := types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(lookupKey.Name), Namespace: ns.Name} createdWorkload := &kueue.Workload{} gomega.Eventually(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return apimeta.IsStatusConditionTrue(createdWorkload.Status.Conditions, kueue.WorkloadAdmitted) @@ -96,9 +99,9 @@ var _ = ginkgo.Describe("Kueue", func() { util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, resourceKueue, true) }) ginkgo.It("Should unsuspend a job", func() { - lookupKey := types.NamespacedName{Name: "test-job", Namespace: ns.Name} createdJob := &batchv1.Job{} createdWorkload := &kueue.Workload{} + lookupKey := types.NamespacedName{Name: "test-job", Namespace: ns.Name} gomega.Eventually(func() bool { if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil { @@ -106,8 +109,9 @@ var _ = ginkgo.Describe("Kueue", func() { } return !*createdJob.Spec.Suspend && createdJob.Status.Succeeded > 0 }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + wlLookupKey := types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(lookupKey.Name), Namespace: ns.Name} gomega.Eventually(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return apimeta.IsStatusConditionTrue(createdWorkload.Status.Conditions, kueue.WorkloadAdmitted) && diff --git a/test/integration/controller/job/job_controller_test.go b/test/integration/controller/job/job_controller_test.go index 3b04f2e9bee..8485382391c 100644 --- a/test/integration/controller/job/job_controller_test.go +++ b/test/integration/controller/job/job_controller_test.go @@ -44,12 +44,15 @@ const ( parallelism = 4 jobName = "test-job" jobNamespace = "default" - jobKey = jobNamespace + "/" + jobName labelKey = "cloud.provider.com/instance" priorityClassName = "test-priority-class" priorityValue = 10 ) +var ( + wlLookupKey = types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(jobName), Namespace: jobNamespace} +) + var ignoreConditionTimestamps = cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime") // +kubebuilder:docs-gen:collapse=Imports @@ -85,7 +88,7 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.By("checking the workload is created without queue assigned") createdWorkload := &kueue.Workload{} gomega.Eventually(func() bool { - err := k8sClient.Get(ctx, lookupKey, createdWorkload) + err := k8sClient.Get(ctx, wlLookupKey, createdWorkload) return err == nil 
}, util.Timeout, util.Interval).Should(gomega.BeTrue()) gomega.Expect(createdWorkload.Spec.QueueName).Should(gomega.Equal(""), "The Workload shouldn't have .spec.queueName set") @@ -100,7 +103,7 @@ var _ = ginkgo.Describe("Job controller", func() { createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName} gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) gomega.Eventually(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return createdWorkload.Spec.QueueName == jobQueueName @@ -108,7 +111,7 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.By("checking a second non-matching workload is deleted") secondWl, _ := workloadjob.ConstructWorkloadFor(ctx, k8sClient, createdJob, scheme.Scheme) - secondWl.Name = "second-workload" + secondWl.Name = workloadjob.GetWorkloadNameForJob("second-workload") secondWl.Spec.PodSets[0].Count = parallelism + 1 gomega.Expect(k8sClient.Create(ctx, secondWl)).Should(gomega.Succeed()) gomega.Eventually(func() error { @@ -118,7 +121,7 @@ var _ = ginkgo.Describe("Job controller", func() { }, util.Timeout, util.Interval).Should(testing.BeNotFoundError()) // check the original wl is still there gomega.Consistently(func() bool { - err := k8sClient.Get(ctx, lookupKey, createdWorkload) + err := k8sClient.Get(ctx, wlLookupKey, createdWorkload) return err == nil }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) gomega.Eventually(func() bool { @@ -158,7 +161,7 @@ var _ = ginkgo.Describe("Job controller", func() { gomega.Expect(len(createdJob.Spec.Template.Spec.NodeSelector)).Should(gomega.Equal(1)) gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(onDemandFlavor.Name)) gomega.Consistently(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return len(createdWorkload.Status.Conditions) == 0 @@ -176,13 +179,13 @@ var _ = ginkgo.Describe("Job controller", func() { len(createdJob.Spec.Template.Spec.NodeSelector) == 0 }, util.Timeout, util.Interval).Should(gomega.BeTrue()) gomega.Eventually(func() bool { - ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "DeletedWorkload", corev1.EventTypeNormal, fmt.Sprintf("Deleted not matching Workload: %v", jobKey)) + ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "DeletedWorkload", corev1.EventTypeNormal, fmt.Sprintf("Deleted not matching Workload: %v", wlLookupKey.String())) return ok }, util.Timeout, util.Interval).Should(gomega.BeTrue()) ginkgo.By("checking the workload is updated with new count") gomega.Eventually(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return createdWorkload.Spec.PodSets[0].Count == newParallelism @@ -208,7 +211,7 @@ var _ = ginkgo.Describe("Job controller", func() { gomega.Expect(len(createdJob.Spec.Template.Spec.NodeSelector)).Should(gomega.Equal(1)) gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(spotFlavor.Name)) gomega.Consistently(func() bool { - if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return false } return len(createdWorkload.Status.Conditions) == 0 @@ -224,7 +227,7 @@ 
var _ = ginkgo.Describe("Job controller", func() { }) gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed()) gomega.Eventually(func() bool { - err := k8sClient.Get(ctx, lookupKey, createdWorkload) + err := k8sClient.Get(ctx, wlLookupKey, createdWorkload) if err != nil || len(createdWorkload.Status.Conditions) == 0 { return false } @@ -236,10 +239,11 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.When("The parent-workload annotation is used", func() { var ( - parentJobName = jobName + "-parent" - parentLookupKey = types.NamespacedName{Name: parentJobName, Namespace: jobNamespace} - childJobName = jobName + "-child" - childLookupKey = types.NamespacedName{Name: childJobName, Namespace: jobNamespace} + parentJobName = jobName + "-parent" + parentWlLookupKey = types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(parentJobName), Namespace: jobNamespace} + childJobName = jobName + "-child" + childLookupKey = types.NamespacedName{Name: childJobName, Namespace: jobNamespace} + childWlLookupKey = types.NamespacedName{Name: workloadjob.GetWorkloadNameForJob(childJobName), Namespace: jobNamespace} ) ginkgo.It("Should suspend a job if the parent workload does not exist", func() { @@ -262,7 +266,7 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.By("waiting for the parent workload to be created") parentWorkload := &kueue.Workload{} gomega.Eventually(func() error { - return k8sClient.Get(ctx, parentLookupKey, parentWorkload) + return k8sClient.Get(ctx, parentWlLookupKey, parentWorkload) }, util.Timeout, util.Interval).Should(gomega.Succeed()) ginkgo.By("Creating the child job which uses the parent workload annotation") @@ -272,7 +276,7 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.By("Checking that the child workload is not created") childWorkload := &kueue.Workload{} gomega.Consistently(func() bool { - return apierrors.IsNotFound(k8sClient.Get(ctx, childLookupKey, childWorkload)) + return apierrors.IsNotFound(k8sClient.Get(ctx, childWlLookupKey, childWorkload)) }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) }) @@ -286,13 +290,13 @@ var _ = ginkgo.Describe("Job controller", func() { gomega.Expect(k8sClient.Create(ctx, parentJob)).Should(gomega.Succeed()) ginkgo.By("Creating the child job with the parent-workload annotation") - childJob := testing.MakeJob(childJobName, jobNamespace).ParentWorkload(parentJobName).Obj() + childJob := testing.MakeJob(childJobName, jobNamespace).ParentWorkload(parentWlLookupKey.Name).Obj() gomega.Expect(k8sClient.Create(ctx, childJob)).Should(gomega.Succeed()) ginkgo.By("waiting for the parent workload to be created") parentWorkload := &kueue.Workload{} gomega.Eventually(func() error { - return k8sClient.Get(ctx, parentLookupKey, parentWorkload) + return k8sClient.Get(ctx, parentWlLookupKey, parentWorkload) }, util.Timeout, util.Interval).Should(gomega.Succeed()) ginkgo.By("admit the parent workload") @@ -316,7 +320,7 @@ var _ = ginkgo.Describe("Job controller", func() { ginkgo.By("Unset admission of the workload to suspend the job") gomega.Eventually(func() error { - if err := k8sClient.Get(ctx, parentLookupKey, parentWorkload); err != nil { + if err := k8sClient.Get(ctx, parentWlLookupKey, parentWorkload); err != nil { return err } parentWorkload.Spec.Admission = nil @@ -353,7 +357,7 @@ var _ = ginkgo.Describe("Job controller for workloads with no queue set", func() createdWorkload := &kueue.Workload{} gomega.Consistently(func() bool { - return 
apierrors.IsNotFound(k8sClient.Get(ctx, lookupKey, createdWorkload)) + return apierrors.IsNotFound(k8sClient.Get(ctx, wlLookupKey, createdWorkload)) }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) ginkgo.By("checking the workload is created when queue name is set") @@ -361,7 +365,7 @@ var _ = ginkgo.Describe("Job controller for workloads with no queue set", func() createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName} gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) gomega.Eventually(func() error { - return k8sClient.Get(ctx, lookupKey, createdWorkload) + return k8sClient.Get(ctx, wlLookupKey, createdWorkload) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) ginkgo.When("The parent-workload annotation is used", func() { @@ -424,7 +428,7 @@ var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { ginkgo.By("Fetch the workload created for the job") createdWorkload := &kueue.Workload{} gomega.Eventually(func() error { - return k8sClient.Get(ctx, lookupKey, createdWorkload) + return k8sClient.Get(ctx, wlLookupKey, createdWorkload) }, util.Timeout, util.Interval).Should(gomega.Succeed()) ginkgo.By("Admit the workload created for the job") @@ -437,7 +441,7 @@ var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { }}, } gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed()) - gomega.Expect(k8sClient.Get(ctx, lookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) ginkgo.By("Await for the job to be unsuspended") gomega.Eventually(func() *bool { @@ -455,7 +459,7 @@ var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { if podsReadyTestSpec.beforeCondition != nil { ginkgo.By("Update the workload status") gomega.Eventually(func() *metav1.Condition { - gomega.Expect(k8sClient.Get(ctx, lookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) return apimeta.FindStatusCondition(createdWorkload.Status.Conditions, kueue.WorkloadPodsReady) }, util.Timeout, util.Interval).Should(gomega.BeComparableTo(podsReadyTestSpec.beforeCondition, ignoreConditionTimestamps)) } @@ -470,7 +474,7 @@ var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { gomega.Eventually(func() error { // the update may need to be retried due to a conflict as the workload gets // also updated due to setting of the job status. 
- if err := k8sClient.Get(ctx, lookupKey, createdWorkload); err != nil { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { return err } createdWorkload.Spec.Admission = nil @@ -480,7 +484,7 @@ var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { ginkgo.By("Verify the PodsReady condition is added") gomega.Eventually(func() *metav1.Condition { - gomega.Expect(k8sClient.Get(ctx, lookupKey, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) return apimeta.FindStatusCondition(createdWorkload.Status.Conditions, kueue.WorkloadPodsReady) }, util.Timeout, util.Interval).Should(gomega.BeComparableTo(podsReadyTestSpec.wantCondition, ignoreConditionTimestamps)) }, diff --git a/test/integration/controller/mpijob/mpijob_controller_test.go b/test/integration/controller/mpijob/mpijob_controller_test.go new file mode 100644 index 00000000000..57a20ccaea0 --- /dev/null +++ b/test/integration/controller/mpijob/mpijob_controller_test.go @@ -0,0 +1,557 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpijob + +import ( + "fmt" + + "github.com/google/go-cmp/cmp/cmpopts" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + apimeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/utils/pointer" + + kueue "sigs.k8s.io/kueue/apis/kueue/v1alpha2" + "sigs.k8s.io/kueue/pkg/constants" + workloadmpijob "sigs.k8s.io/kueue/pkg/controller/workload/mpijob" + "sigs.k8s.io/kueue/pkg/util/testing" + "sigs.k8s.io/kueue/test/integration/framework" + "sigs.k8s.io/kueue/test/util" +) + +const ( + jobName = "test-job" + jobNamespace = "default" + labelKey = "cloud.provider.com/instance" + priorityClassName = "test-priority-class" + priorityValue = 10 +) + +var ignoreConditionTimestamps = cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime") + +var ( + wlLookupKey = types.NamespacedName{Name: workloadmpijob.GetWorkloadNameForMPIJob(jobName), Namespace: jobNamespace} +) + +// +kubebuilder:docs-gen:collapse=Imports + +var _ = ginkgo.Describe("Job controller", func() { + + ginkgo.BeforeEach(func() { + fwk = &framework.Framework{ + ManagerSetup: managerSetup(workloadmpijob.WithManageJobsWithoutQueueName(true)), + CRDPath: crdPath, + DepCRDPaths: []string{mpiCrdPath}, + } + + ctx, cfg, k8sClient = fwk.Setup() + }) + ginkgo.AfterEach(func() { + fwk.Teardown() + }) + + ginkgo.It("Should reconcile MPIJobs", func() { + ginkgo.By("checking the job gets suspended when created unsuspended") + priorityClass := testing.MakePriorityClass(priorityClassName). 
+ PriorityValue(int32(priorityValue)).Obj() + gomega.Expect(k8sClient.Create(ctx, priorityClass)).Should(gomega.Succeed()) + + job := testing.MakeMPIJob(jobName, jobNamespace).PriorityClass(priorityClassName).Obj() + err := k8sClient.Create(ctx, job) + gomega.Expect(err).To(gomega.Succeed()) + createdJob := &kubeflow.MPIJob{} + + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, types.NamespacedName{Name: jobName, Namespace: jobNamespace}, createdJob); err != nil { + return false + } + return createdJob.Spec.RunPolicy.Suspend != nil && *createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the workload is created without queue assigned") + createdWorkload := &kueue.Workload{} + gomega.Eventually(func() error { + return k8sClient.Get(ctx, wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(createdWorkload.Spec.QueueName).Should(gomega.Equal(""), "The Workload shouldn't have .spec.queueName set") + gomega.Expect(metav1.IsControlledBy(createdWorkload, createdJob)).To(gomega.BeTrue(), "The Workload should be owned by the Job") + + ginkgo.By("checking the workload is created with priority and priorityName") + gomega.Expect(createdWorkload.Spec.PriorityClassName).Should(gomega.Equal(priorityClassName)) + gomega.Expect(*createdWorkload.Spec.Priority).Should(gomega.Equal(int32(priorityValue))) + + ginkgo.By("checking the workload is updated with queue name when the job does") + jobQueueName := "test-queue" + createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName} + gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { + return false + } + return createdWorkload.Spec.QueueName == jobQueueName + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking a second non-matching workload is deleted") + secondWl, _ := workloadmpijob.ConstructWorkloadFor(ctx, k8sClient, createdJob, scheme.Scheme) + secondWl.Name = workloadmpijob.GetWorkloadNameForMPIJob("second-workload") + secondWl.Spec.PodSets[0].Count += 1 + gomega.Expect(k8sClient.Create(ctx, secondWl)).Should(gomega.Succeed()) + gomega.Eventually(func() error { + wl := &kueue.Workload{} + key := types.NamespacedName{Name: secondWl.Name, Namespace: secondWl.Namespace} + return k8sClient.Get(ctx, key, wl) + }, util.Timeout, util.Interval).Should(testing.BeNotFoundError()) + // check the original wl is still there + gomega.Consistently(func() bool { + err := k8sClient.Get(ctx, wlLookupKey, createdWorkload) + return err == nil + }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the job is unsuspended when workload is assigned") + onDemandFlavor := testing.MakeResourceFlavor("on-demand").Label(labelKey, "on-demand").Obj() + gomega.Expect(k8sClient.Create(ctx, onDemandFlavor)).Should(gomega.Succeed()) + spotFlavor := testing.MakeResourceFlavor("spot").Label(labelKey, "spot").Obj() + gomega.Expect(k8sClient.Create(ctx, spotFlavor)).Should(gomega.Succeed()) + clusterQueue := testing.MakeClusterQueue("cluster-queue"). + Resource(testing.MakeResource(corev1.ResourceCPU). + Flavor(testing.MakeFlavor(onDemandFlavor.Name, "5").Obj()). + Flavor(testing.MakeFlavor(spotFlavor.Name, "5").Obj()). 
+ Obj()).Obj() + createdWorkload.Spec.Admission = &kueue.Admission{ + ClusterQueue: kueue.ClusterQueueReference(clusterQueue.Name), + PodSetFlavors: []kueue.PodSetFlavors{{ + Name: "Launcher", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: onDemandFlavor.Name, + }, + }, { + Name: "Worker", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: spotFlavor.Name, + }, + }}, + } + lookupKey := types.NamespacedName{Name: jobName, Namespace: jobNamespace} + gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed()) + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil { + return false + } + return !*createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + gomega.Eventually(func() bool { + ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "Started", corev1.EventTypeNormal, fmt.Sprintf("Admitted by clusterQueue %v", clusterQueue.Name)) + return ok + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector)).Should(gomega.Equal(1)) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(onDemandFlavor.Name)) + gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector)).Should(gomega.Equal(1)) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(spotFlavor.Name)) + gomega.Consistently(func() bool { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { + return false + } + return len(createdWorkload.Status.Conditions) == 0 + }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the job gets suspended when parallelism changes and the added node selectors are removed") + parallelism := pointer.Int32Deref(job.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Replicas, 1) + newParallelism := int32(parallelism + 1) + createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Replicas = &newParallelism + gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil { + return false + } + return createdJob.Spec.RunPolicy.Suspend != nil && *createdJob.Spec.RunPolicy.Suspend && + len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector) == 0 + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + gomega.Eventually(func() bool { + ok, _ := testing.CheckLatestEvent(ctx, k8sClient, "DeletedWorkload", corev1.EventTypeNormal, fmt.Sprintf("Deleted not matching Workload: %v", wlLookupKey.String())) + return ok + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the workload is updated with new count") + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { + return false + } + return createdWorkload.Spec.PodSets[1].Count == newParallelism + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + gomega.Expect(createdWorkload.Spec.Admission).Should(gomega.BeNil()) + + ginkgo.By("checking the job is unsuspended and selectors added when workload is assigned again") + createdWorkload.Spec.Admission = &kueue.Admission{ + ClusterQueue: 
kueue.ClusterQueueReference(clusterQueue.Name), + PodSetFlavors: []kueue.PodSetFlavors{{ + Name: "Launcher", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: onDemandFlavor.Name, + }, + }, { + Name: "Worker", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: spotFlavor.Name, + }, + }}, + } + gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed()) + gomega.Eventually(func() bool { + if err := k8sClient.Get(ctx, lookupKey, createdJob); err != nil { + return false + } + return !*createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector)).Should(gomega.Equal(1)) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(onDemandFlavor.Name)) + gomega.Expect(len(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector)).Should(gomega.Equal(1)) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector[labelKey]).Should(gomega.Equal(spotFlavor.Name)) + gomega.Consistently(func() bool { + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { + return false + } + return len(createdWorkload.Status.Conditions) == 0 + }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the workload is finished when job is completed") + createdJob.Status.Conditions = append(createdJob.Status.Conditions, + kubeflow.JobCondition{ + Type: kubeflow.JobSucceeded, + Status: corev1.ConditionTrue, + LastTransitionTime: metav1.Now(), + }) + gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed()) + gomega.Eventually(func() bool { + err := k8sClient.Get(ctx, wlLookupKey, createdWorkload) + if err != nil || len(createdWorkload.Status.Conditions) == 0 { + return false + } + + return createdWorkload.Status.Conditions[0].Type == kueue.WorkloadFinished && + createdWorkload.Status.Conditions[0].Status == metav1.ConditionTrue + }, util.Timeout, util.Interval).Should(gomega.BeTrue()) + }) +}) + +var _ = ginkgo.Describe("Job controller for workloads when only jobs with queue are managed", func() { + ginkgo.BeforeEach(func() { + fwk = &framework.Framework{ + ManagerSetup: managerSetup(), + CRDPath: crdPath, + DepCRDPaths: []string{mpiCrdPath}, + } + ctx, cfg, k8sClient = fwk.Setup() + }) + ginkgo.AfterEach(func() { + fwk.Teardown() + }) + ginkgo.It("Should reconcile jobs only when queue is set", func() { + ginkgo.By("checking the workload is not created when queue name is not set") + job := testing.MakeMPIJob(jobName, jobNamespace).Obj() + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + lookupKey := types.NamespacedName{Name: jobName, Namespace: jobNamespace} + createdJob := &kubeflow.MPIJob{} + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + + createdWorkload := &kueue.Workload{} + gomega.Consistently(func() bool { + return apierrors.IsNotFound(k8sClient.Get(ctx, wlLookupKey, createdWorkload)) + }, util.ConsistentDuration, util.Interval).Should(gomega.BeTrue()) + + ginkgo.By("checking the workload is created when queue name is set") + jobQueueName := "test-queue" + createdJob.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName} + gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed()) + 
gomega.Eventually(func() error { + return k8sClient.Get(ctx, wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + +}) + +var _ = ginkgo.Describe("Job controller when waitForPodsReady enabled", func() { + type podsReadyTestSpec struct { + beforeJobStatus *kubeflow.JobStatus + beforeCondition *metav1.Condition + jobStatus kubeflow.JobStatus + suspended bool + wantCondition *metav1.Condition + } + + ginkgo.BeforeEach(func() { + fwk = &framework.Framework{ + ManagerSetup: managerSetup(workloadmpijob.WithWaitForPodsReady(true)), + CRDPath: crdPath, + DepCRDPaths: []string{mpiCrdPath}, + } + ctx, cfg, k8sClient = fwk.Setup() + }) + ginkgo.AfterEach(func() { + fwk.Teardown() + }) + + ginkgo.DescribeTable("Single job at different stages of progress towards completion", + func(podsReadyTestSpec podsReadyTestSpec) { + ginkgo.By("Create a resource flavor") + defaultFlavor := testing.MakeResourceFlavor("default").Label(labelKey, "default").Obj() + gomega.Expect(k8sClient.Create(ctx, defaultFlavor)).Should(gomega.Succeed()) + + ginkgo.By("Create a job") + job := testing.MakeMPIJob(jobName, jobNamespace).Parallelism(2).Obj() + jobQueueName := "test-queue" + job.Annotations = map[string]string{constants.QueueAnnotation: jobQueueName} + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + lookupKey := types.NamespacedName{Name: jobName, Namespace: jobNamespace} + createdJob := &kubeflow.MPIJob{} + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + + ginkgo.By("Fetch the workload created for the job") + createdWorkload := &kueue.Workload{} + gomega.Eventually(func() error { + return k8sClient.Get(ctx, wlLookupKey, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Admit the workload created for the job") + createdWorkload.Spec.Admission = &kueue.Admission{ + ClusterQueue: kueue.ClusterQueueReference("foo"), + PodSetFlavors: []kueue.PodSetFlavors{{ + Name: "Launcher", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: defaultFlavor.Name, + }, + }, { + Name: "Worker", + Flavors: map[corev1.ResourceName]string{ + corev1.ResourceCPU: defaultFlavor.Name, + }, + }}, + } + gomega.Expect(k8sClient.Update(ctx, createdWorkload)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + + ginkgo.By("Await for the job to be unsuspended") + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) + + if podsReadyTestSpec.beforeJobStatus != nil { + ginkgo.By("Update the job status to simulate its initial progress towards completion") + createdJob.Status = *podsReadyTestSpec.beforeJobStatus + gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + } + + if podsReadyTestSpec.beforeCondition != nil { + ginkgo.By("Update the workload status") + gomega.Eventually(func() *metav1.Condition { + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + return apimeta.FindStatusCondition(createdWorkload.Status.Conditions, kueue.WorkloadPodsReady) + }, util.Timeout, util.Interval).Should(gomega.BeComparableTo(podsReadyTestSpec.beforeCondition, ignoreConditionTimestamps)) + } + + ginkgo.By("Update the job 
status to simulate its progress towards completion") + createdJob.Status = podsReadyTestSpec.jobStatus + gomega.Expect(k8sClient.Status().Update(ctx, createdJob)).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Get(ctx, lookupKey, createdJob)).Should(gomega.Succeed()) + + if podsReadyTestSpec.suspended { + ginkgo.By("Unset admission of the workload to suspend the job") + gomega.Eventually(func() error { + // the update may need to be retried due to a conflict as the workload gets + // also updated due to setting of the job status. + if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil { + return err + } + createdWorkload.Spec.Admission = nil + return k8sClient.Update(ctx, createdWorkload) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + } + + ginkgo.By("Verify the PodsReady condition is added") + gomega.Eventually(func() *metav1.Condition { + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed()) + return apimeta.FindStatusCondition(createdWorkload.Status.Conditions, kueue.WorkloadPodsReady) + }, util.Timeout, util.Interval).Should(gomega.BeComparableTo(podsReadyTestSpec.wantCondition, ignoreConditionTimestamps)) + }, + ginkgo.Entry("No progress", podsReadyTestSpec{ + wantCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionFalse, + Reason: "PodsReady", + Message: "Not all pods are ready or succeeded", + }, + }), + ginkgo.Entry("Running MPIJob", podsReadyTestSpec{ + jobStatus: kubeflow.JobStatus{ + Conditions: []kubeflow.JobCondition{ + { + Type: kubeflow.JobRunning, + Status: corev1.ConditionTrue, + Reason: "Running", + }, + }, + }, + wantCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionTrue, + Reason: "PodsReady", + Message: "All pods were ready or succeeded since the workload admission", + }, + }), + ginkgo.Entry("Running MPIJob; PodsReady=False before", podsReadyTestSpec{ + beforeCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionFalse, + Reason: "PodsReady", + Message: "Not all pods are ready or succeeded", + }, + jobStatus: kubeflow.JobStatus{ + Conditions: []kubeflow.JobCondition{ + { + Type: kubeflow.JobRunning, + Status: corev1.ConditionTrue, + Reason: "Running", + }, + }, + }, + wantCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionTrue, + Reason: "PodsReady", + Message: "All pods were ready or succeeded since the workload admission", + }, + }), + ginkgo.Entry("Job suspended; PodsReady=True before", podsReadyTestSpec{ + beforeJobStatus: &kubeflow.JobStatus{ + Conditions: []kubeflow.JobCondition{ + { + Type: kubeflow.JobRunning, + Status: corev1.ConditionTrue, + Reason: "Running", + }, + }, + }, + beforeCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionTrue, + Reason: "PodsReady", + Message: "All pods were ready or succeeded since the workload admission", + }, + jobStatus: kubeflow.JobStatus{ + Conditions: []kubeflow.JobCondition{ + { + Type: kubeflow.JobRunning, + Status: corev1.ConditionFalse, + Reason: "Suspended", + }, + }, + }, + suspended: true, + wantCondition: &metav1.Condition{ + Type: kueue.WorkloadPodsReady, + Status: metav1.ConditionFalse, + Reason: "PodsReady", + Message: "Not all pods are ready or succeeded", + }, + }), + ) +}) + +var _ = ginkgo.Describe("Job controller interacting with scheduler", func() { + const ( + instanceKey = "cloud.provider.com/instance" + ) + + var ( + ns *corev1.Namespace + 
onDemandFlavor *kueue.ResourceFlavor + spotUntaintedFlavor *kueue.ResourceFlavor + clusterQueue *kueue.ClusterQueue + localQueue *kueue.LocalQueue + ) + + ginkgo.BeforeEach(func() { + fwk = &framework.Framework{ + ManagerSetup: managerAndSchedulerSetup(), + CRDPath: crdPath, + DepCRDPaths: []string{mpiCrdPath}, + } + ctx, cfg, k8sClient = fwk.Setup() + + ns = &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "core-", + }, + } + gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) + + onDemandFlavor = testing.MakeResourceFlavor("on-demand").Label(instanceKey, "on-demand").Obj() + gomega.Expect(k8sClient.Create(ctx, onDemandFlavor)).Should(gomega.Succeed()) + + spotUntaintedFlavor = testing.MakeResourceFlavor("spot-untainted").Label(instanceKey, "spot-untainted").Obj() + gomega.Expect(k8sClient.Create(ctx, spotUntaintedFlavor)).Should(gomega.Succeed()) + + clusterQueue = testing.MakeClusterQueue("dev-clusterqueue"). + Resource(testing.MakeResource(corev1.ResourceCPU). + Flavor(testing.MakeFlavor(spotUntaintedFlavor.Name, "5").Obj()). + Flavor(testing.MakeFlavor(onDemandFlavor.Name, "5").Obj()). + Obj()). + Obj() + gomega.Expect(k8sClient.Create(ctx, clusterQueue)).Should(gomega.Succeed()) + }) + + ginkgo.AfterEach(func() { + gomega.Expect(util.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed()) + util.ExpectClusterQueueToBeDeleted(ctx, k8sClient, clusterQueue, true) + util.ExpectResourceFlavorToBeDeleted(ctx, k8sClient, onDemandFlavor, true) + gomega.Expect(util.DeleteResourceFlavor(ctx, k8sClient, spotUntaintedFlavor)).To(gomega.Succeed()) + + fwk.Teardown() + }) + + ginkgo.It("Should schedule jobs as they fit in their ClusterQueue", func() { + ginkgo.By("creating localQueue") + localQueue = testing.MakeLocalQueue("local-queue", ns.Name).ClusterQueue(clusterQueue.Name).Obj() + gomega.Expect(k8sClient.Create(ctx, localQueue)).Should(gomega.Succeed()) + + ginkgo.By("checking a dev job starts") + job := testing.MakeMPIJob("dev-job", ns.Name).Queue(localQueue.Name). + Request(kubeflow.MPIReplicaTypeLauncher, corev1.ResourceCPU, "3"). + Request(kubeflow.MPIReplicaTypeWorker, corev1.ResourceCPU, "4"). + Obj() + gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed()) + createdJob := &kubeflow.MPIJob{} + gomega.Eventually(func() *bool { + gomega.Expect(k8sClient.Get(ctx, types.NamespacedName{Name: job.Name, Namespace: job.Namespace}, createdJob)). + Should(gomega.Succeed()) + return createdJob.Spec.RunPolicy.Suspend + }, util.Timeout, util.Interval).Should(gomega.Equal(pointer.Bool(false))) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template.Spec.NodeSelector[instanceKey]).Should(gomega.Equal(spotUntaintedFlavor.Name)) + gomega.Expect(createdJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker].Template.Spec.NodeSelector[instanceKey]).Should(gomega.Equal(onDemandFlavor.Name)) + util.ExpectPendingWorkloadsMetric(clusterQueue, 0, 0) + util.ExpectAdmittedActiveWorkloadsMetric(clusterQueue, 1) + + }) + +}) diff --git a/test/integration/controller/mpijob/suite_test.go b/test/integration/controller/mpijob/suite_test.go new file mode 100644 index 00000000000..f02b1df99bf --- /dev/null +++ b/test/integration/controller/mpijob/suite_test.go @@ -0,0 +1,99 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mpijob + +import ( + "context" + "path/filepath" + "testing" + + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + "k8s.io/client-go/rest" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/manager" + + config "sigs.k8s.io/kueue/apis/config/v1alpha2" + "sigs.k8s.io/kueue/pkg/cache" + "sigs.k8s.io/kueue/pkg/constants" + "sigs.k8s.io/kueue/pkg/controller/core" + "sigs.k8s.io/kueue/pkg/controller/core/indexer" + "sigs.k8s.io/kueue/pkg/controller/workload/mpijob" + "sigs.k8s.io/kueue/pkg/queue" + "sigs.k8s.io/kueue/pkg/scheduler" + "sigs.k8s.io/kueue/test/integration/framework" + //+kubebuilder:scaffold:imports +) + +var ( + cfg *rest.Config + k8sClient client.Client + ctx context.Context + fwk *framework.Framework + crdPath = filepath.Join("..", "..", "..", "..", "config", "components", "crd", "bases") + mpiCrdPath = filepath.Join("..", "..", "..", "..", "dep-crds", "mpi-operator") +) + +func TestAPIs(t *testing.T) { + gomega.RegisterFailHandler(ginkgo.Fail) + + ginkgo.RunSpecs(t, + "MPIJob Controller Suite", + ) +} + +func managerSetup(opts ...mpijob.Option) framework.ManagerSetup { + return func(mgr manager.Manager, ctx context.Context) { + reconciler := mpijob.NewReconciler( + mgr.GetScheme(), + mgr.GetClient(), + mgr.GetEventRecorderFor(constants.JobControllerName), + opts...) + err := mpijob.SetupIndexes(ctx, mgr.GetFieldIndexer()) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = reconciler.SetupWithManager(mgr) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = mpijob.SetupMPIJobWebhook(mgr, opts...) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + } +} + +func managerAndSchedulerSetup(opts ...mpijob.Option) framework.ManagerSetup { + return func(mgr manager.Manager, ctx context.Context) { + err := indexer.Setup(ctx, mgr.GetFieldIndexer()) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + cCache := cache.New(mgr.GetClient()) + queues := queue.NewManager(mgr.GetClient(), cCache) + + failedCtrl, err := core.SetupControllers(mgr, queues, cCache, &config.Configuration{}) + gomega.Expect(err).ToNot(gomega.HaveOccurred(), "controller", failedCtrl) + + err = mpijob.SetupIndexes(ctx, mgr.GetFieldIndexer()) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = mpijob.NewReconciler(mgr.GetScheme(), mgr.GetClient(), + mgr.GetEventRecorderFor(constants.JobControllerName), opts...).SetupWithManager(mgr) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = mpijob.SetupMPIJobWebhook(mgr, opts...) 
+ gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + sched := scheduler.New(queues, cCache, mgr.GetClient(), mgr.GetEventRecorderFor(constants.AdmissionName)) + go func() { + sched.Start(ctx) + }() + } +} diff --git a/test/integration/framework/framework.go b/test/integration/framework/framework.go index 39943aa52d6..48b30fc766a 100644 --- a/test/integration/framework/framework.go +++ b/test/integration/framework/framework.go @@ -23,6 +23,7 @@ import ( "net" "time" + kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" zaplog "go.uber.org/zap" @@ -42,6 +43,7 @@ type ManagerSetup func(manager.Manager, context.Context) type Framework struct { CRDPath string + DepCRDPaths []string WebhookPath string ManagerSetup ManagerSetup testEnv *envtest.Environment @@ -62,7 +64,7 @@ func (f *Framework) Setup() (context.Context, *rest.Config, client.Client) { ginkgo.By("bootstrapping test environment") f.testEnv = &envtest.Environment{ - CRDDirectoryPaths: []string{f.CRDPath}, + CRDDirectoryPaths: append(f.DepCRDPaths, f.CRDPath), ErrorIfCRDPathMissing: true, } webhookEnabled := len(f.WebhookPath) > 0 @@ -77,6 +79,9 @@ func (f *Framework) Setup() (context.Context, *rest.Config, client.Client) { err = kueue.AddToScheme(scheme.Scheme) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) + err = kubeflow.AddToScheme(scheme.Scheme) + gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) + // +kubebuilder:scaffold:scheme k8sClient, err := client.New(cfg, client.Options{Scheme: scheme.Scheme})