From 22d2635dd26d7c5dec4bc921c01b4c70dc6bf63d Mon Sep 17 00:00:00 2001 From: Honnix Date: Mon, 31 Jul 2023 16:04:56 +0200 Subject: [PATCH] Support gRPC config for agent-service plugin (#368) * Support gRPC config for agent-service plugin Signed-off-by: Hongxin Liang * Address comments Signed-off-by: Hongxin Liang * No deadline if timeout is 0 Signed-off-by: Hongxin Liang * Per task type grpc endpoint config Signed-off-by: Hongxin Liang * Rename config items according to comments Signed-off-by: Hongxin Liang --------- Signed-off-by: Hongxin Liang --- go/tasks/plugins/webapi/agent/config.go | 35 ++++++- go/tasks/plugins/webapi/agent/config_test.go | 24 +++++ .../plugins/webapi/agent/integration_test.go | 4 +- go/tasks/plugins/webapi/agent/plugin.go | 91 +++++++++++++++---- go/tasks/plugins/webapi/agent/plugin_test.go | 64 ++++++++++--- 5 files changed, 183 insertions(+), 35 deletions(-) diff --git a/go/tasks/plugins/webapi/agent/config.go b/go/tasks/plugins/webapi/agent/config.go index 14993b240..58355c5ee 100644 --- a/go/tasks/plugins/webapi/agent/config.go +++ b/go/tasks/plugins/webapi/agent/config.go @@ -39,8 +39,12 @@ var ( Value: 50, }, }, - DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80", - SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, + DefaultAgent: Agent{ + Endpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80", + Insecure: true, + DefaultTimeout: config.Duration{Duration: 10 * time.Second}, + }, + SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, } configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig) @@ -54,15 +58,36 @@ type Config struct { // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` - DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."` + // The default agent if there does not exist a more specific matching against task types + DefaultAgent Agent `json:"defaultAgent" pflag:",The default agent."` + + // The agents used to match against specific task types. {AgentId: Agent} + Agents map[string]*Agent `json:"agents" pflag:",The agents."` - // Maps endpoint to their plugin handler. {TaskType: Endpoint} - EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"` + // Maps task types to their agents. {TaskType: AgentId} + AgentForTaskTypes map[string]string `json:"agentForTaskTypes" pflag:"-,"` // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` } +type Agent struct { + // Endpoint points to an agent gRPC endpoint + Endpoint string `json:"endpoint"` + + // Insecure indicates whether the communication with the gRPC service is insecure + Insecure bool `json:"insecure"` + + // DefaultServiceConfig sets default gRPC service config; check https://github.com/grpc/grpc/blob/master/doc/service_config.md for more details + DefaultServiceConfig string `json:"defaultServiceConfig"` + + // Timeouts defines various RPC timeout values for different plugin operations: CreateTask, GetTask, DeleteTask; if not configured, defaults to DefaultTimeout + Timeouts map[string]config.Duration `json:"timeouts"` + + // DefaultTimeout gives the default RPC timeout if a more specific one is not defined in Timeouts; if neither DefaultTimeout nor Timeouts is defined for an operation, RPC timeout will not be enforced + DefaultTimeout config.Duration `json:"defaultTimeout"` +} + func GetConfig() *Config { return configSection.GetConfig().(*Config) } diff --git a/go/tasks/plugins/webapi/agent/config_test.go b/go/tasks/plugins/webapi/agent/config_test.go index e7201a2b9..a32805159 100644 --- a/go/tasks/plugins/webapi/agent/config_test.go +++ b/go/tasks/plugins/webapi/agent/config_test.go @@ -4,6 +4,8 @@ import ( "testing" "time" + "github.com/flyteorg/flytestdlib/config" + "github.com/stretchr/testify/assert" ) @@ -11,6 +13,28 @@ func TestGetAndSetConfig(t *testing.T) { cfg := defaultConfig cfg.WebAPI.Caching.Workers = 1 cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultAgent.Insecure = false + cfg.DefaultAgent.DefaultServiceConfig = "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}" + cfg.DefaultAgent.Timeouts = map[string]config.Duration{ + "CreateTask": { + Duration: 1 * time.Millisecond, + }, + "GetTask": { + Duration: 2 * time.Millisecond, + }, + "DeleteTask": { + Duration: 3 * time.Millisecond, + }, + } + cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second} + cfg.Agents = map[string]*Agent{ + "endpoint_1": { + Insecure: cfg.DefaultAgent.Insecure, + DefaultServiceConfig: cfg.DefaultAgent.DefaultServiceConfig, + Timeouts: cfg.DefaultAgent.Timeouts, + }, + } + cfg.AgentForTaskTypes = map[string]string{"task_type_1": "endpoint_1"} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, &cfg, GetConfig()) diff --git a/go/tasks/plugins/webapi/agent/integration_test.go b/go/tasks/plugins/webapi/agent/integration_test.go index 30036ede7..a3990e631 100644 --- a/go/tasks/plugins/webapi/agent/integration_test.go +++ b/go/tasks/plugins/webapi/agent/integration_test.go @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ return &admin.DeleteTaskResponse{}, nil } -func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return &MockClient{}, nil } -func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func mockGetBadClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { return nil, fmt.Errorf("error") } diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index dbcb568d4..f8665aa60 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -2,10 +2,14 @@ package agent import ( "context" + "crypto/x509" "encoding/gob" "fmt" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flytestdlib/config" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" @@ -21,13 +25,13 @@ import ( "google.golang.org/grpc" ) -type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetClientFunc func(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) type Plugin struct { metricScope promutils.Scope cfg *Config getClient GetClientFunc - connectionCache map[string]*grpc.ClientConn + connectionCache map[*Agent]*grpc.ClientConn } type ResourceWrapper struct { @@ -66,14 +70,20 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() - endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg) + if err != nil { + return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } + finalCtx, cancel := getFinalContext(ctx, "CreateTask", endpoint) + defer cancel() + taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix, TaskExecutionMetadata: &taskExecutionMetadata}) + res, err := client.CreateTask(finalCtx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix, TaskExecutionMetadata: &taskExecutionMetadata}) if err != nil { return nil, nil, err } @@ -89,13 +99,19 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + if err != nil { + return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - res, err := client.GetTask(ctx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + finalCtx, cancel := getFinalContext(ctx, "GetTask", endpoint) + defer cancel() + + res, err := client.GetTask(finalCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) if err != nil { return nil, err } @@ -112,13 +128,19 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes) + endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + if err != nil { + return fmt.Errorf("failed to find agent endpoint with error: %v", err) + } client, err := p.getClient(ctx, endpoint, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) } - _, err = client.DeleteTask(ctx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + finalCtx, cancel := getFinalContext(ctx, "DeleteTask", endpoint) + defer cancel() + + _, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) return err } @@ -145,24 +167,43 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) } -func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map[string]string) string { - if t, exists := endpointForTaskTypes[taskType]; exists { - return t +func getFinalEndpoint(taskType string, cfg *Config) (*Agent, error) { + if id, exists := cfg.AgentForTaskTypes[taskType]; exists { + if endpoint, exists := cfg.Agents[id]; exists { + return endpoint, nil + } + return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType) } - return defaultEndpoint + return &cfg.DefaultAgent, nil } -func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { conn, ok := connectionCache[endpoint] if ok { return service.NewAsyncAgentServiceClient(conn), nil } + var opts []grpc.DialOption - var err error - opts = append(opts, grpc.WithInsecure()) - conn, err = grpc.Dial(endpoint, opts...) + if endpoint.Insecure { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(pool, "") + opts = append(opts, grpc.WithTransportCredentials(creds)) + } + + if len(endpoint.DefaultServiceConfig) != 0 { + opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig)) + } + + var err error + conn, err = grpc.Dial(endpoint.Endpoint, opts...) if err != nil { return nil, err } @@ -196,6 +237,22 @@ func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionM } } +func getFinalTimeout(operation string, endpoint *Agent) config.Duration { + if t, exists := endpoint.Timeouts[operation]; exists { + return t + } + + return endpoint.DefaultTimeout +} + +func getFinalContext(ctx context.Context, operation string, endpoint *Agent) (context.Context, context.CancelFunc) { + timeout := getFinalTimeout(operation, endpoint).Duration + if timeout == 0 { + return ctx, func() {} + } + return context.WithTimeout(ctx, timeout) +} + func newAgentPlugin() webapi.PluginEntry { supportedTaskTypes := GetConfig().SupportedTaskTypes @@ -207,7 +264,7 @@ func newAgentPlugin() webapi.PluginEntry { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), getClient: getClientFunc, - connectionCache: make(map[string]*grpc.ClientConn), + connectionCache: make(map[*Agent]*grpc.ClientConn), }, nil }, } diff --git a/go/tasks/plugins/webapi/agent/plugin_test.go b/go/tasks/plugins/webapi/agent/plugin_test.go index 174115eea..31c8fe034 100644 --- a/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/go/tasks/plugins/webapi/agent/plugin_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/flyteorg/flytestdlib/config" + "google.golang.org/grpc" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -17,16 +19,18 @@ func TestPlugin(t *testing.T) { fakeSetupContext := pluginCoreMocks.SetupContext{} fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultAgent = Agent{Endpoint: "test-agent.flyte.svc.cluster.local:80"} + cfg.Agents = map[string]*Agent{"spark_agent": {Endpoint: "localhost:80"}} + cfg.AgentForTaskTypes = map[string]string{"spark": "spark_agent", "bar": "bar_agent"} + plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), } t.Run("get config", func(t *testing.T) { - cfg := defaultConfig - cfg.WebAPI.Caching.Workers = 1 - cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second - cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80" - cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, cfg.WebAPI, plugin.GetConfig()) @@ -41,20 +45,58 @@ func TestPlugin(t *testing.T) { t.Run("tet newAgentPlugin", func(t *testing.T) { p := newAgentPlugin() assert.NotNil(t, p) - assert.Equal(t, p.ID, "agent-service") + assert.Equal(t, "agent-service", p.ID) assert.NotNil(t, p.PluginLoader) }) t.Run("test getFinalEndpoint", func(t *testing.T) { - endpoint := getFinalEndpoint("spark", "localhost:8080", map[string]string{"spark": "localhost:80"}) - assert.Equal(t, endpoint, "localhost:80") - endpoint = getFinalEndpoint("spark", "localhost:8080", map[string]string{}) - assert.Equal(t, endpoint, "localhost:8080") + endpoint, _ := getFinalEndpoint("spark", &cfg) + assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, endpoint.Endpoint) + endpoint, _ = getFinalEndpoint("foo", &cfg) + assert.Equal(t, cfg.DefaultAgent.Endpoint, endpoint.Endpoint) + _, err := getFinalEndpoint("bar", &cfg) + assert.NotNil(t, err) }) t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), "localhost:80", map[string]*grpc.ClientConn{}) + client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) + assert.NoError(t, err) + assert.NotNil(t, client) + }) + + t.Run("test getClientFunc more config", func(t *testing.T) { + client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) + assert.NoError(t, err) + assert.NotNil(t, client) + }) + + t.Run("test getClientFunc cache hit", func(t *testing.T) { + connectionCache := make(map[*Agent]*grpc.ClientConn) + endpoint := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} + + client, err := getClientFunc(context.Background(), endpoint, connectionCache) assert.NoError(t, err) assert.NotNil(t, client) + assert.NotNil(t, client, connectionCache[endpoint]) + + cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache) + assert.NoError(t, err) + assert.NotNil(t, cachedClient) + assert.Equal(t, client, cachedClient) + }) + + t.Run("test getFinalTimeout", func(t *testing.T) { + timeout := getFinalTimeout("CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + assert.Equal(t, 1*time.Millisecond, timeout.Duration) + timeout = getFinalTimeout("DeleteTask", &Agent{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}}) + assert.Equal(t, 10*time.Second, timeout.Duration) + }) + + t.Run("test getFinalContext", func(t *testing.T) { + ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &Agent{}) + assert.Equal(t, context.TODO(), ctx) + + ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + assert.NotEqual(t, context.TODO(), ctx) }) }