diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 41a28ff590..8bbb5e1f3d 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -30,6 +30,7 @@ import ( utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" _ "k8s.io/client-go/plugin/pkg/client/auth" + "k8s.io/client-go/util/flowcontrol" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/healthz" @@ -81,6 +82,8 @@ func main() { var webhookServerPort int var webhookServiceName string var webhookSecretName string + var clientQps int + var clientBurst int flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") @@ -95,7 +98,8 @@ func main() { flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ "If set, it only monitors kubeflow jobs in the given namespace.") flag.IntVar(&controllerThreads, "controller-threads", 1, "Number of worker threads used by the controller.") - + flag.IntVar(&clientQps, "kube-api-qps", 20, "QPS indicates the maximum QPS to the master from this client.") + flag.IntVar(&clientBurst, "kube-api-burst", 30, "Maximum burst for throttle.") // PyTorch related flags flag.StringVar(&config.Config.PyTorchInitContainerImage, "pytorch-init-container-image", config.PyTorchInitContainerImageDefault, "The image for pytorch init container") @@ -131,7 +135,10 @@ func main() { } } - mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{ + cfg := ctrl.GetConfigOrDie() + cfg.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(float32(clientQps), clientBurst) + + mgr, err := ctrl.NewManager(cfg, ctrl.Options{ Scheme: scheme, Metrics: metricsserver.Options{ BindAddress: metricsAddr,