Skip to content

Commit 13f7875

Browse files
authored
[kubectl-plugin] Add head/worker node selector option (#3228)
* add node selector option for kubectl plugin create cluster Signed-off-by: Troy Chiu <[email protected]> * nit Signed-off-by: Troy Chiu <[email protected]> --------- Signed-off-by: Troy Chiu <[email protected]>
1 parent 928d690 commit 13f7875

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

kubectl-plugin/pkg/cmd/create/create_cluster.go

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ type CreateClusterOptions struct {
2323
ioStreams *genericclioptions.IOStreams
2424
workerRayStartParams map[string]string
2525
headRayStartParams map[string]string
26+
headNodeSelectors map[string]string
27+
workerNodeSelectors map[string]string
2628
kubeContexter util.KubeContexter
2729
clusterName string
2830
rayVersion string
@@ -106,6 +108,8 @@ func NewCreateClusterCommand(streams genericclioptions.IOStreams) *cobra.Command
106108
cmd.Flags().BoolVar(&options.dryRun, "dry-run", false, "print the generated YAML instead of creating the cluster")
107109
cmd.Flags().BoolVar(&options.wait, "wait", false, "wait for the cluster to be provisioned before returning. Returns an error if the cluster is not provisioned by the timeout specified")
108110
cmd.Flags().DurationVar(&options.timeout, "timeout", defaultProvisionedTimeout, "the timeout for --wait")
111+
cmd.Flags().StringToStringVar(&options.headNodeSelectors, "head-node-selectors", nil, "Node selectors to apply to all head pods in the cluster (e.g. --head-node-selector=cloud.google.com/gke-accelerator=nvidia-l4,cloud.google.com/gke-nodepool=my-node-pool)")
112+
cmd.Flags().StringToStringVar(&options.workerNodeSelectors, "worker-node-selectors", nil, "Node selectors to apply to all worker pods in the cluster (e.g. --worker-node-selector=cloud.google.com/gke-accelerator=nvidia-l4,cloud.google.com/gke-nodepool=my-node-pool)")
109113

110114
options.configFlags.AddFlags(cmd.Flags())
111115
return cmd
@@ -182,6 +186,8 @@ func (options *CreateClusterOptions) Run(ctx context.Context, k8sClient client.C
182186
WorkerEphemeralStorage: options.workerEphemeralStorage,
183187
WorkerGPU: options.workerGPU,
184188
WorkerRayStartParams: options.workerRayStartParams,
189+
HeadNodeSelectors: options.headNodeSelectors,
190+
WorkerNodeSelectors: options.workerNodeSelectors,
185191
},
186192
}
187193

kubectl-plugin/pkg/util/generation/generation.go

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
type RayClusterSpecObject struct {
2020
HeadRayStartParams map[string]string
2121
WorkerRayStartParams map[string]string
22+
HeadNodeSelectors map[string]string
23+
WorkerNodeSelectors map[string]string
2224
RayVersion string
2325
Image string
2426
HeadCPU string
@@ -107,6 +109,7 @@ func (rayClusterSpecObject *RayClusterSpecObject) generateRayClusterSpec() *rayv
107109
WithRayStartParams(headRayStartParams).
108110
WithTemplate(corev1ac.PodTemplateSpec().
109111
WithSpec(corev1ac.PodSpec().
112+
WithNodeSelector(rayClusterSpecObject.HeadNodeSelectors).
110113
WithContainers(corev1ac.Container().
111114
WithName("ray-head").
112115
WithImage(rayClusterSpecObject.Image).
@@ -122,6 +125,7 @@ func (rayClusterSpecObject *RayClusterSpecObject) generateRayClusterSpec() *rayv
122125
WithReplicas(rayClusterSpecObject.WorkerReplicas).
123126
WithTemplate(corev1ac.PodTemplateSpec().
124127
WithSpec(corev1ac.PodSpec().
128+
WithNodeSelector(rayClusterSpecObject.WorkerNodeSelectors).
125129
WithContainers(corev1ac.Container().
126130
WithName("ray-worker").
127131
WithImage(rayClusterSpecObject.Image).

kubectl-plugin/pkg/util/generation/generation_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,14 @@ func TestGenerateRayClusterSpec(t *testing.T) {
257257
WorkerCPU: "2",
258258
WorkerMemory: "10Gi",
259259
WorkerGPU: "0",
260+
HeadNodeSelectors: map[string]string{
261+
"head-selector1": "foo",
262+
"head-selector2": "bar",
263+
},
264+
WorkerNodeSelectors: map[string]string{
265+
"worker-selector1": "baz",
266+
"worker-selector2": "qux",
267+
},
260268
}
261269

262270
expected := &rayv1ac.RayClusterSpecApplyConfiguration{
@@ -299,6 +307,10 @@ func TestGenerateRayClusterSpec(t *testing.T) {
299307
},
300308
},
301309
},
310+
NodeSelector: map[string]string{
311+
"head-selector1": "foo",
312+
"head-selector2": "bar",
313+
},
302314
},
303315
},
304316
},
@@ -325,6 +337,10 @@ func TestGenerateRayClusterSpec(t *testing.T) {
325337
},
326338
},
327339
},
340+
NodeSelector: map[string]string{
341+
"worker-selector1": "baz",
342+
"worker-selector2": "qux",
343+
},
328344
},
329345
},
330346
},

0 commit comments

Comments
 (0)