Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support gRPC config for agent-service plugin (#368)
Browse files Browse the repository at this point in the history
* Support gRPC config for agent-service plugin

Signed-off-by: Hongxin Liang <[email protected]>

* Address comments

Signed-off-by: Hongxin Liang <[email protected]>

* No deadline if timeout is 0

Signed-off-by: Hongxin Liang <[email protected]>

* Per task type grpc endpoint config

Signed-off-by: Hongxin Liang <[email protected]>

* Rename config items according to comments

Signed-off-by: Hongxin Liang <[email protected]>

---------

Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Jul 31, 2023
1 parent 3405f89 commit 22d2635
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 35 deletions.
35 changes: 30 additions & 5 deletions go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ var (
Value: 50,
},
},
DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
DefaultAgent: Agent{
Endpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
Insecure: true,
DefaultTimeout: config.Duration{Duration: 10 * time.Second},
},
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -54,15 +58,36 @@ type Config struct {
// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`
// The default agent if there does not exist a more specific matching against task types
DefaultAgent Agent `json:"defaultAgent" pflag:",The default agent."`

// The agents used to match against specific task types. {AgentId: Agent}
Agents map[string]*Agent `json:"agents" pflag:",The agents."`

// Maps endpoint to their plugin handler. {TaskType: Endpoint}
EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"`
// Maps task types to their agents. {TaskType: AgentId}
AgentForTaskTypes map[string]string `json:"agentForTaskTypes" pflag:"-,"`

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`
}

type Agent struct {
// Endpoint points to an agent gRPC endpoint
Endpoint string `json:"endpoint"`

// Insecure indicates whether the communication with the gRPC service is insecure
Insecure bool `json:"insecure"`

// DefaultServiceConfig sets default gRPC service config; check https://github.com/grpc/grpc/blob/master/doc/service_config.md for more details
DefaultServiceConfig string `json:"defaultServiceConfig"`

// Timeouts defines various RPC timeout values for different plugin operations: CreateTask, GetTask, DeleteTask; if not configured, defaults to DefaultTimeout
Timeouts map[string]config.Duration `json:"timeouts"`

// DefaultTimeout gives the default RPC timeout if a more specific one is not defined in Timeouts; if neither DefaultTimeout nor Timeouts is defined for an operation, RPC timeout will not be enforced
DefaultTimeout config.Duration `json:"defaultTimeout"`
}

func GetConfig() *Config {
return configSection.GetConfig().(*Config)
}
Expand Down
24 changes: 24 additions & 0 deletions go/tasks/plugins/webapi/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,37 @@ import (
"testing"
"time"

"github.com/flyteorg/flytestdlib/config"

"github.com/stretchr/testify/assert"
)

func TestGetAndSetConfig(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultAgent.Insecure = false
cfg.DefaultAgent.DefaultServiceConfig = "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"
cfg.DefaultAgent.Timeouts = map[string]config.Duration{
"CreateTask": {
Duration: 1 * time.Millisecond,
},
"GetTask": {
Duration: 2 * time.Millisecond,
},
"DeleteTask": {
Duration: 3 * time.Millisecond,
},
}
cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second}
cfg.Agents = map[string]*Agent{
"endpoint_1": {
Insecure: cfg.DefaultAgent.Insecure,
DefaultServiceConfig: cfg.DefaultAgent.DefaultServiceConfig,
Timeouts: cfg.DefaultAgent.Timeouts,
},
}
cfg.AgentForTaskTypes = map[string]string{"task_type_1": "endpoint_1"}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _
return &admin.DeleteTaskResponse{}, nil
}

func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return &MockClient{}, nil
}

func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetBadClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return nil, fmt.Errorf("error")
}

Expand Down
91 changes: 74 additions & 17 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package agent

import (
"context"
"crypto/x509"
"encoding/gob"
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flytestdlib/config"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"google.golang.org/grpc/grpclog"

Expand All @@ -21,13 +25,13 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)
type GetClientFunc func(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
cfg *Config
getClient GetClientFunc
connectionCache map[string]*grpc.ClientConn
connectionCache map[*Agent]*grpc.ClientConn
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -66,14 +70,20 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR

outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "CreateTask", endpoint)
defer cancel()

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix, TaskExecutionMetadata: &taskExecutionMetadata})
res, err := client.CreateTask(finalCtx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix, TaskExecutionMetadata: &taskExecutionMetadata})
if err != nil {
return nil, nil, err
}
Expand All @@ -89,13 +99,19 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.GetTask(ctx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
finalCtx, cancel := getFinalContext(ctx, "GetTask", endpoint)
defer cancel()

res, err := client.GetTask(finalCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
if err != nil {
return nil, err
}
Expand All @@ -112,13 +128,19 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
if err != nil {
return fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

_, err = client.DeleteTask(ctx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
finalCtx, cancel := getFinalContext(ctx, "DeleteTask", endpoint)
defer cancel()

_, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
return err
}

Expand All @@ -145,24 +167,43 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map[string]string) string {
if t, exists := endpointForTaskTypes[taskType]; exists {
return t
func getFinalEndpoint(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if endpoint, exists := cfg.Agents[id]; exists {
return endpoint, nil
}
return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType)
}

return defaultEndpoint
return &cfg.DefaultAgent, nil
}

func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
if ok {
return service.NewAsyncAgentServiceClient(conn), nil
}

var opts []grpc.DialOption
var err error

opts = append(opts, grpc.WithInsecure())
conn, err = grpc.Dial(endpoint, opts...)
if endpoint.Insecure {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

creds := credentials.NewClientTLSFromCert(pool, "")
opts = append(opts, grpc.WithTransportCredentials(creds))
}

if len(endpoint.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig))
}

var err error
conn, err = grpc.Dial(endpoint.Endpoint, opts...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -196,6 +237,22 @@ func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionM
}
}

func getFinalTimeout(operation string, endpoint *Agent) config.Duration {
if t, exists := endpoint.Timeouts[operation]; exists {
return t
}

return endpoint.DefaultTimeout
}

func getFinalContext(ctx context.Context, operation string, endpoint *Agent) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, endpoint).Duration
if timeout == 0 {
return ctx, func() {}
}
return context.WithTimeout(ctx, timeout)
}

func newAgentPlugin() webapi.PluginEntry {
supportedTaskTypes := GetConfig().SupportedTaskTypes

Expand All @@ -207,7 +264,7 @@ func newAgentPlugin() webapi.PluginEntry {
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[string]*grpc.ClientConn),
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
},
}
Expand Down
64 changes: 53 additions & 11 deletions go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"testing"
"time"

"github.com/flyteorg/flytestdlib/config"

"google.golang.org/grpc"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -17,16 +19,18 @@ func TestPlugin(t *testing.T) {
fakeSetupContext := pluginCoreMocks.SetupContext{}
fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test"))

cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultAgent = Agent{Endpoint: "test-agent.flyte.svc.cluster.local:80"}
cfg.Agents = map[string]*Agent{"spark_agent": {Endpoint: "localhost:80"}}
cfg.AgentForTaskTypes = map[string]string{"spark": "spark_agent", "bar": "bar_agent"}

plugin := Plugin{
metricScope: fakeSetupContext.MetricsScope(),
cfg: GetConfig(),
}
t.Run("get config", func(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80"
cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, cfg.WebAPI, plugin.GetConfig())
Expand All @@ -41,20 +45,58 @@ func TestPlugin(t *testing.T) {
t.Run("tet newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
assert.NotNil(t, p)
assert.Equal(t, p.ID, "agent-service")
assert.Equal(t, "agent-service", p.ID)
assert.NotNil(t, p.PluginLoader)
})

t.Run("test getFinalEndpoint", func(t *testing.T) {
endpoint := getFinalEndpoint("spark", "localhost:8080", map[string]string{"spark": "localhost:80"})
assert.Equal(t, endpoint, "localhost:80")
endpoint = getFinalEndpoint("spark", "localhost:8080", map[string]string{})
assert.Equal(t, endpoint, "localhost:8080")
endpoint, _ := getFinalEndpoint("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, endpoint.Endpoint)
endpoint, _ = getFinalEndpoint("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, endpoint.Endpoint)
_, err := getFinalEndpoint("bar", &cfg)
assert.NotNil(t, err)
})

t.Run("test getClientFunc", func(t *testing.T) {
client, err := getClientFunc(context.Background(), "localhost:80", map[string]*grpc.ClientConn{})
client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getClientFunc more config", func(t *testing.T) {
client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getClientFunc cache hit", func(t *testing.T) {
connectionCache := make(map[*Agent]*grpc.ClientConn)
endpoint := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}

client, err := getClientFunc(context.Background(), endpoint, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client, connectionCache[endpoint])

cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, cachedClient)
assert.Equal(t, client, cachedClient)
})

t.Run("test getFinalTimeout", func(t *testing.T) {
timeout := getFinalTimeout("CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.Equal(t, 1*time.Millisecond, timeout.Duration)
timeout = getFinalTimeout("DeleteTask", &Agent{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}})
assert.Equal(t, 10*time.Second, timeout.Duration)
})

t.Run("test getFinalContext", func(t *testing.T) {
ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &Agent{})
assert.Equal(t, context.TODO(), ctx)

ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.NotEqual(t, context.TODO(), ctx)
})
}

0 comments on commit 22d2635

Please sign in to comment.