diff --git a/.github/workflows/validate_examples.yaml b/.github/workflows/validate_examples.yaml index fc67fd1d..28e5a692 100644 --- a/.github/workflows/validate_examples.yaml +++ b/.github/workflows/validate_examples.yaml @@ -14,15 +14,15 @@ on: workflow_dispatch: inputs: daprdapr_commit: - description: 'Dapr/Dapr commit to build custom daprd from' + description: "Dapr/Dapr commit to build custom daprd from" required: false - default: '' + default: "" daprcli_commit: - description: 'Dapr/CLI commit to build custom dapr CLI from' + description: "Dapr/CLI commit to build custom dapr CLI from" required: false - default: '' + default: "" repository_dispatch: - types: [ validate-examples ] + types: [validate-examples] merge_group: jobs: setup: @@ -154,7 +154,17 @@ jobs: strategy: fail-fast: false matrix: - examples: [ "actor", "configuration", "grpc-service", "hello-world", "pubsub", "service", "socket" ] + examples: + [ + "actor", + "configuration", + "grpc-service", + "hello-world", + "pubsub", + "service", + "socket", + "workflow", + ] steps: - name: Check out code onto GOPATH uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index 0a63249e..29365795 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,7 @@ cover: ## Displays test coverage in the client and service packages go test -coverprofile=cover-client.out ./client && go tool cover -html=cover-client.out go test -coverprofile=cover-grpc.out ./service/grpc && go tool cover -html=cover-grpc.out go test -coverprofile=cover-http.out ./service/http && go tool cover -html=cover-http.out + go test -coverprofile=cover-workflow.out ./workflow && go tool cover -html=cover-workflow.out .PHONY: lint lint: check-lint ## Lints the entire project diff --git a/client/client.go b/client/client.go index b8f19660..207d5861 100644 --- a/client/client.go +++ b/client/client.go @@ -209,8 +209,31 @@ type Client interface { // ImplActorClientStub is to impl user defined actor client stub ImplActorClientStub(actorClientStub actor.Client, opt ...config.Option) + // StartWorkflowBeta1 starts a workflow. + StartWorkflowBeta1(ctx context.Context, req *StartWorkflowRequest) (*StartWorkflowResponse, error) + + // GetWorkflowBeta1 gets a workflow. + GetWorkflowBeta1(ctx context.Context, req *GetWorkflowRequest) (*GetWorkflowResponse, error) + + // PurgeWorkflowBeta1 purges a workflow. + PurgeWorkflowBeta1(ctx context.Context, req *PurgeWorkflowRequest) error + + // TerminateWorkflowBeta1 terminates a workflow. + TerminateWorkflowBeta1(ctx context.Context, req *TerminateWorkflowRequest) error + + // PauseWorkflowBeta1 pauses a workflow. + PauseWorkflowBeta1(ctx context.Context, req *PauseWorkflowRequest) error + + // ResumeWorkflowBeta1 resumes a workflow. + ResumeWorkflowBeta1(ctx context.Context, req *ResumeWorkflowRequest) error + + // RaiseEventWorkflowBeta1 raises an event for a workflow. + RaiseEventWorkflowBeta1(ctx context.Context, req *RaiseEventWorkflowRequest) error + // GrpcClient returns the base grpc client if grpc is used and nil otherwise GrpcClient() pb.DaprClient + + GrpcClientConn() *grpc.ClientConn } // NewClient instantiates Dapr client using DAPR_GRPC_PORT environment variable as port. diff --git a/client/client_test.go b/client/client_test.go index 65fa1529..4fc7e603 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -39,8 +39,9 @@ import ( ) const ( - testBufSize = 1024 * 1024 - testSocket = "/tmp/dapr.socket" + testBufSize = 1024 * 1024 + testSocket = "/tmp/dapr.socket" + testWorkflowFailureID = "test_failure_id" ) var testClient Client @@ -500,6 +501,62 @@ func (s *testDaprServer) UnsubscribeConfiguration(ctx context.Context, in *pb.Un return &pb.UnsubscribeConfigurationResponse{Ok: true}, nil } +func (s *testDaprServer) StartWorkflowBeta1(ctx context.Context, in *pb.StartWorkflowRequest) (*pb.StartWorkflowResponse, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &pb.StartWorkflowResponse{ + InstanceId: in.GetInstanceId(), + }, nil +} + +func (s *testDaprServer) GetWorkflowBeta1(ctx context.Context, in *pb.GetWorkflowRequest) (*pb.GetWorkflowResponse, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &pb.GetWorkflowResponse{ + InstanceId: in.GetInstanceId(), + WorkflowName: "TestWorkflowName", + RuntimeStatus: "Running", + Properties: make(map[string]string), + }, nil +} + +func (s *testDaprServer) PurgeWorkflowBeta1(ctx context.Context, in *pb.PurgeWorkflowRequest) (*emptypb.Empty, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &emptypb.Empty{}, nil +} + +func (s *testDaprServer) TerminateWorkflowBeta1(ctx context.Context, in *pb.TerminateWorkflowRequest) (*emptypb.Empty, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &emptypb.Empty{}, nil +} + +func (s *testDaprServer) PauseWorkflowBeta1(ctx context.Context, in *pb.PauseWorkflowRequest) (*emptypb.Empty, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &emptypb.Empty{}, nil +} + +func (s *testDaprServer) ResumeWorkflowBeta1(ctx context.Context, in *pb.ResumeWorkflowRequest) (*emptypb.Empty, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &emptypb.Empty{}, nil +} + +func (s *testDaprServer) RaiseEventWorkflowBeta1(ctx context.Context, in *pb.RaiseEventWorkflowRequest) (*emptypb.Empty, error) { + if in.GetInstanceId() == testWorkflowFailureID { + return nil, errors.New("test failure") + } + return &emptypb.Empty{}, nil +} + func TestGrpcClient(t *testing.T) { protoClient := pb.NewDaprClient(nil) client := &GRPCClient{protoClient: protoClient} diff --git a/client/workflow.go b/client/workflow.go new file mode 100644 index 00000000..75b105d3 --- /dev/null +++ b/client/workflow.go @@ -0,0 +1,268 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + + pb "github.com/dapr/dapr/pkg/proto/runtime/v1" +) + +const ( + DefaultWorkflowComponent = "dapr" +) + +type StartWorkflowRequest struct { + InstanceID string // Optional instance identifier + WorkflowComponent string + WorkflowName string + Options map[string]string // Optional metadata + Input any // Optional input + SendRawInput bool // Set to True in order to disable serialization on the input +} + +type StartWorkflowResponse struct { + InstanceID string +} + +type GetWorkflowRequest struct { + InstanceID string + WorkflowComponent string +} + +type GetWorkflowResponse struct { + InstanceID string + WorkflowName string + CreatedAt time.Time + LastUpdatedAt time.Time + RuntimeStatus string + Properties map[string]string +} + +type PurgeWorkflowRequest struct { + InstanceID string + WorkflowComponent string +} + +type TerminateWorkflowRequest struct { + InstanceID string + WorkflowComponent string +} + +type PauseWorkflowRequest struct { + InstanceID string + WorkflowComponent string +} + +type ResumeWorkflowRequest struct { + InstanceID string + WorkflowComponent string +} + +type RaiseEventWorkflowRequest struct { + InstanceID string + WorkflowComponent string + EventName string + EventData any + SendRawData bool // Set to True in order to disable serialization on the data +} + +// StartWorkflowBeta1 starts a workflow using the beta1 spec. +func (c *GRPCClient) StartWorkflowBeta1(ctx context.Context, req *StartWorkflowRequest) (*StartWorkflowResponse, error) { + if req.InstanceID == "" { + req.InstanceID = uuid.New().String() + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + if req.WorkflowName == "" { + return nil, errors.New("failed to start workflow: WorkflowName must be supplied") + } + + var input []byte + var err error + if req.SendRawInput { + var ok bool + if input, ok = req.Input.([]byte); !ok { + return nil, errors.New("failed to start workflow: sendrawinput is true however, input is not a byte slice") + } + } else { + input, err = marshalInput(req.Input) + if err != nil { + return nil, fmt.Errorf("failed to start workflow: %v", err) + } + } + + resp, err := c.protoClient.StartWorkflowBeta1(ctx, &pb.StartWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + WorkflowName: req.WorkflowName, + Options: req.Options, + Input: input, + }) + if err != nil { + return nil, fmt.Errorf("failed to start workflow instance: %v", err) + } + return &StartWorkflowResponse{ + InstanceID: resp.GetInstanceId(), + }, nil +} + +// GetWorkflowBeta1 gets the status of a workflow using the beta1 spec. +func (c *GRPCClient) GetWorkflowBeta1(ctx context.Context, req *GetWorkflowRequest) (*GetWorkflowResponse, error) { + if req.InstanceID == "" { + return nil, errors.New("failed to get workflow status: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + resp, err := c.protoClient.GetWorkflowBeta1(ctx, &pb.GetWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + }) + if err != nil { + return nil, fmt.Errorf("failed to get workflow status: %v", err) + } + return &GetWorkflowResponse{ + InstanceID: resp.GetInstanceId(), + WorkflowName: resp.GetWorkflowName(), + CreatedAt: resp.GetCreatedAt().AsTime(), + LastUpdatedAt: resp.GetLastUpdatedAt().AsTime(), + RuntimeStatus: resp.GetRuntimeStatus(), + Properties: resp.GetProperties(), + }, nil +} + +// PurgeWorkflowBeta1 removes all metadata relating to a specific workflow using the beta1 spec. +func (c *GRPCClient) PurgeWorkflowBeta1(ctx context.Context, req *PurgeWorkflowRequest) error { + if req.InstanceID == "" { + return errors.New("failed to purge workflow: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + _, err := c.protoClient.PurgeWorkflowBeta1(ctx, &pb.PurgeWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + }) + if err != nil { + return fmt.Errorf("failed to purge workflow: %v", err) + } + return nil +} + +// TerminateWorkflowBeta1 stops a workflow using the beta1 spec. +func (c *GRPCClient) TerminateWorkflowBeta1(ctx context.Context, req *TerminateWorkflowRequest) error { + if req.InstanceID == "" { + return errors.New("failed to terminate workflow: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + _, err := c.protoClient.TerminateWorkflowBeta1(ctx, &pb.TerminateWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + }) + if err != nil { + return fmt.Errorf("failed to terminate workflow: %v", err) + } + return nil +} + +// PauseWorkflowBeta1 pauses a workflow that can be resumed later using the beta1 spec. +func (c *GRPCClient) PauseWorkflowBeta1(ctx context.Context, req *PauseWorkflowRequest) error { + if req.InstanceID == "" { + return errors.New("failed to pause workflow: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + _, err := c.protoClient.PauseWorkflowBeta1(ctx, &pb.PauseWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + }) + if err != nil { + return fmt.Errorf("failed to pause workflow: %v", err) + } + return nil +} + +// ResumeWorkflowBeta1 resumes a paused workflow using the beta1 spec. +func (c *GRPCClient) ResumeWorkflowBeta1(ctx context.Context, req *ResumeWorkflowRequest) error { + if req.InstanceID == "" { + return errors.New("failed to resume workflow: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + _, err := c.protoClient.ResumeWorkflowBeta1(ctx, &pb.ResumeWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + }) + if err != nil { + return fmt.Errorf("failed to resume workflow: %v", err) + } + return nil +} + +// RaiseEventWorkflowBeta1 raises an event on a workflow using the beta1 spec. +func (c *GRPCClient) RaiseEventWorkflowBeta1(ctx context.Context, req *RaiseEventWorkflowRequest) error { + if req.InstanceID == "" { + return errors.New("failed to raise event on workflow: InstanceID must be supplied") + } + if req.WorkflowComponent == "" { + req.WorkflowComponent = DefaultWorkflowComponent + } + if req.EventName == "" { + return errors.New("failed to raise event on workflow: EventName must be supplied") + } + var eventData []byte + var err error + if req.SendRawData { + var ok bool + if eventData, ok = req.EventData.([]byte); !ok { + return errors.New("failed to raise event on workflow: SendRawData is true however, eventData is not a byte slice") + } + } else { + eventData, err = marshalInput(req.EventData) + if err != nil { + return fmt.Errorf("failed to raise an event on workflow: %v", err) + } + } + + _, err = c.protoClient.RaiseEventWorkflowBeta1(ctx, &pb.RaiseEventWorkflowRequest{ + InstanceId: req.InstanceID, + WorkflowComponent: req.WorkflowComponent, + EventName: req.EventName, + EventData: eventData, + }) + if err != nil { + return fmt.Errorf("failed to raise event on workflow: %v", err) + } + return nil +} + +func marshalInput(input any) (data []byte, err error) { + if input == nil { + return nil, nil + } + return json.Marshal(input) +} diff --git a/client/workflow_test.go b/client/workflow_test.go new file mode 100644 index 00000000..3beee0e9 --- /dev/null +++ b/client/workflow_test.go @@ -0,0 +1,375 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package client + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" +) + +func TestMarshalInput(t *testing.T) { + var input any + t.Run("string", func(t *testing.T) { + input = "testString" + data, err := marshalInput(input) + require.NoError(t, err) + assert.Equal(t, []byte{0x22, 0x74, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x22}, data) + }) +} + +func TestWorkflowBeta1(t *testing.T) { + ctx := context.Background() + + // 1: StartWorkflow + t.Run("start workflow - valid (without id)", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + }) + require.NoError(t, err) + assert.NotNil(t, resp.InstanceID) + }) + t.Run("start workflow - valid (with id)", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + }) + require.NoError(t, err) + assert.Equal(t, "TestID", resp.InstanceID) + }) + t.Run("start workflow - valid (without component name)", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + WorkflowName: "TestWorkflow", + }) + require.NoError(t, err) + assert.Equal(t, "TestID", resp.InstanceID) + }) + t.Run("start workflow - rpc failure", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + t.Run("start workflow - grpc failure", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + WorkflowName: "", + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + t.Run("start workflow - cannot serialize input", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + Input: math.NaN(), + SendRawInput: false, + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + t.Run("start workflow - raw input", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + Input: []byte("stringtest"), + SendRawInput: true, + }) + require.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("start workflow - raw input (invalid)", func(t *testing.T) { + resp, err := testClient.StartWorkflowBeta1(ctx, &StartWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + WorkflowName: "TestWorkflow", + Input: "test string", + SendRawInput: true, + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + + // 2: GetWorkflow + t.Run("get workflow", func(t *testing.T) { + resp, err := testClient.GetWorkflowBeta1(ctx, &GetWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("get workflow - valid", func(t *testing.T) { + resp, err := testClient.GetWorkflowBeta1(ctx, &GetWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("get workflow - valid (without component)", func(t *testing.T) { + resp, err := testClient.GetWorkflowBeta1(ctx, &GetWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + }) + require.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("get workflow - invalid id", func(t *testing.T) { + resp, err := testClient.GetWorkflowBeta1(ctx, &GetWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("get workflow - grpc fail", func(t *testing.T) { + resp, err := testClient.GetWorkflowBeta1(ctx, &GetWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + }) + require.Error(t, err) + assert.Nil(t, resp) + }) + + // 3: PauseWorkflow + t.Run("pause workflow", func(t *testing.T) { + err := testClient.PauseWorkflowBeta1(ctx, &PauseWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + }) + + t.Run("pause workflow - valid (without component)", func(t *testing.T) { + err := testClient.PauseWorkflowBeta1(ctx, &PauseWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + }) + require.NoError(t, err) + }) + + t.Run("pause workflow invalid instanceid", func(t *testing.T) { + err := testClient.PauseWorkflowBeta1(ctx, &PauseWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + t.Run("pause workflow", func(t *testing.T) { + err := testClient.PauseWorkflowBeta1(ctx, &PauseWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + // 4: ResumeWorkflow + t.Run("resume workflow", func(t *testing.T) { + err := testClient.ResumeWorkflowBeta1(ctx, &ResumeWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + }) + + t.Run("resume workflow - valid (without component)", func(t *testing.T) { + err := testClient.ResumeWorkflowBeta1(ctx, &ResumeWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + }) + require.NoError(t, err) + }) + + t.Run("resume workflow - invalid instanceid", func(t *testing.T) { + err := testClient.ResumeWorkflowBeta1(ctx, &ResumeWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + t.Run("resume workflow - grpc fail", func(t *testing.T) { + err := testClient.ResumeWorkflowBeta1(ctx, &ResumeWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + // 5: TerminateWorkflow + t.Run("terminate workflow", func(t *testing.T) { + err := testClient.TerminateWorkflowBeta1(ctx, &TerminateWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + }) + + t.Run("terminate workflow - valid (without component)", func(t *testing.T) { + err := testClient.TerminateWorkflowBeta1(ctx, &TerminateWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + }) + require.NoError(t, err) + }) + + t.Run("terminate workflow - invalid instanceid", func(t *testing.T) { + err := testClient.TerminateWorkflowBeta1(ctx, &TerminateWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + t.Run("terminate workflow - grpc failure", func(t *testing.T) { + err := testClient.TerminateWorkflowBeta1(ctx, &TerminateWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + // 6: RaiseEventWorkflow + t.Run("raise event workflow", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + EventName: "TestEvent", + }) + require.NoError(t, err) + }) + + t.Run("raise event workflow - valid (without component)", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + EventName: "TestEvent", + }) + require.NoError(t, err) + }) + + t.Run("raise event workflow - invalid instanceid", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + EventName: "TestEvent", + }) + require.Error(t, err) + }) + + t.Run("raise event workflow - invalid eventname", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + EventName: "", + }) + require.Error(t, err) + }) + + t.Run("raise event workflow - grpc failure", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + EventName: "TestEvent", + }) + require.Error(t, err) + }) + t.Run("raise event workflow - cannot serialize input", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + EventName: "TestEvent", + EventData: math.NaN(), + SendRawData: false, + }) + require.Error(t, err) + }) + t.Run("raise event workflow - raw input", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + EventName: "TestEvent", + EventData: []byte("teststring"), + SendRawData: true, + }) + require.NoError(t, err) + }) + + t.Run("raise event workflow - raw input (invalid)", func(t *testing.T) { + err := testClient.RaiseEventWorkflowBeta1(ctx, &RaiseEventWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + EventName: "TestEvent", + EventData: "test string", + SendRawData: true, + }) + require.Error(t, err) + }) + + // 7: PurgeWorkflow + t.Run("purge workflow", func(t *testing.T) { + err := testClient.PurgeWorkflowBeta1(ctx, &PurgeWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "dapr", + }) + require.NoError(t, err) + }) + + t.Run("purge workflow - valid (without component)", func(t *testing.T) { + err := testClient.PurgeWorkflowBeta1(ctx, &PurgeWorkflowRequest{ + InstanceID: "TestID", + WorkflowComponent: "", + }) + require.NoError(t, err) + }) + + t.Run("purge workflow - invalid instanceid", func(t *testing.T) { + err := testClient.PurgeWorkflowBeta1(ctx, &PurgeWorkflowRequest{ + InstanceID: "", + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) + + t.Run("purge workflow - grpc failure", func(t *testing.T) { + err := testClient.PurgeWorkflowBeta1(ctx, &PurgeWorkflowRequest{ + InstanceID: testWorkflowFailureID, + WorkflowComponent: "dapr", + }) + require.Error(t, err) + }) +} diff --git a/examples/workflow/README.md b/examples/workflow/README.md new file mode 100644 index 00000000..d962e5e7 --- /dev/null +++ b/examples/workflow/README.md @@ -0,0 +1,83 @@ +# Dapr Workflow Example with go-sdk + +## Step + +### Prepare + +- Dapr installed + +### Run Workflow + + + +```bash +dapr run --app-id workflow \ + --dapr-grpc-port 50001 \ + --log-level debug \ + --resources-path ./config \ + -- go run ./main.go +``` + + + +## Result + +``` + - '== APP == Worker initialized' + - '== APP == TestWorkflow registered' + - '== APP == TestActivity registered' + - '== APP == runner started' + - '== APP == workflow started with id: a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9' + - '== APP == workflow paused' + - '== APP == workflow resumed' + - '== APP == stage: 1' + - '== APP == workflow event raised' + - '== APP == stage: 2' + - '== APP == workflow status: COMPLETED' + - '== APP == workflow purged' + - '== APP == stage: 2' + - '== APP == workflow started with id: a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9' + - '== APP == workflow terminated' + - '== APP == workflow purged' + - '== APP == workflow client test' + - '== APP == [wfclient] started workflow with id: a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9' + - '== APP == [wfclient] workflow status: RUNNING' + - '== APP == [wfclient] stage: 1' + - '== APP == [wfclient] event raised' + - '== APP == [wfclient] stage: 2' + - '== APP == [wfclient] workflow terminated' + - '== APP == [wfclient] workflow purged' + - '== APP == workflow worker successfully shutdown' +``` diff --git a/examples/workflow/config/redis.yaml b/examples/workflow/config/redis.yaml new file mode 100644 index 00000000..5bb57b3f --- /dev/null +++ b/examples/workflow/config/redis.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: wf-store +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: actorStateStore + value: "true" diff --git a/examples/workflow/main.go b/examples/workflow/main.go new file mode 100644 index 00000000..99c16407 --- /dev/null +++ b/examples/workflow/main.go @@ -0,0 +1,333 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/dapr/go-sdk/client" + "github.com/dapr/go-sdk/workflow" +) + +var stage = 0 + +const ( + workflowComponent = "dapr" +) + +func main() { + w, err := workflow.NewWorker() + if err != nil { + log.Fatal(err) + } + + fmt.Println("Worker initialized") + + if err := w.RegisterWorkflow(TestWorkflow); err != nil { + log.Fatal(err) + } + fmt.Println("TestWorkflow registered") + + if err := w.RegisterActivity(TestActivity); err != nil { + log.Fatal(err) + } + fmt.Println("TestActivity registered") + + // Start workflow runner + if err := w.Start(); err != nil { + log.Fatal(err) + } + fmt.Println("runner started") + + daprClient, err := client.NewClient() + if err != nil { + log.Fatalf("failed to intialise client: %v", err) + } + defer daprClient.Close() + ctx := context.Background() + + // Start workflow test + respStart, err := daprClient.StartWorkflowBeta1(ctx, &client.StartWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + WorkflowName: "TestWorkflow", + Options: nil, + Input: 1, + SendRawInput: false, + }) + if err != nil { + log.Fatalf("failed to start workflow: %v", err) + } + fmt.Printf("workflow started with id: %v\n", respStart.InstanceID) + + // Pause workflow test + err = daprClient.PauseWorkflowBeta1(ctx, &client.PauseWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + + if err != nil { + log.Fatalf("failed to pause workflow: %v", err) + } + + respGet, err := daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to get workflow: %v", err) + } + + if respGet.RuntimeStatus != workflow.StatusSuspended.String() { + log.Fatalf("workflow not paused: %v", respGet.RuntimeStatus) + } + + fmt.Printf("workflow paused\n") + + // Resume workflow test + err = daprClient.ResumeWorkflowBeta1(ctx, &client.ResumeWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + + if err != nil { + log.Fatalf("failed to resume workflow: %v", err) + } + + respGet, err = daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to get workflow: %v", err) + } + + if respGet.RuntimeStatus != workflow.StatusRunning.String() { + log.Fatalf("workflow not running") + } + + fmt.Println("workflow resumed") + + fmt.Printf("stage: %d\n", stage) + + // Raise Event Test + + err = daprClient.RaiseEventWorkflowBeta1(ctx, &client.RaiseEventWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + EventName: "testEvent", + EventData: "testData", + SendRawData: false, + }) + + if err != nil { + fmt.Printf("failed to raise event: %v", err) + } + + fmt.Println("workflow event raised") + + time.Sleep(time.Second) // allow workflow to advance + + fmt.Printf("stage: %d\n", stage) + + respGet, err = daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to get workflow: %v", err) + } + + fmt.Printf("workflow status: %v\n", respGet.RuntimeStatus) + + // Purge workflow test + err = daprClient.PurgeWorkflowBeta1(ctx, &client.PurgeWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to purge workflow: %v", err) + } + + respGet, err = daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil && respGet != nil { + log.Fatal("failed to purge workflow") + } + + fmt.Println("workflow purged") + + fmt.Printf("stage: %d\n", stage) + + // Terminate workflow test + respStart, err = daprClient.StartWorkflowBeta1(ctx, &client.StartWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + WorkflowName: "TestWorkflow", + Options: nil, + Input: 1, + SendRawInput: false, + }) + if err != nil { + log.Fatalf("failed to start workflow: %v", err) + } + + fmt.Printf("workflow started with id: %s\n", respStart.InstanceID) + + err = daprClient.TerminateWorkflowBeta1(ctx, &client.TerminateWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to terminate workflow: %v", err) + } + + respGet, err = daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err != nil { + log.Fatalf("failed to get workflow: %v", err) + } + if respGet.RuntimeStatus != workflow.StatusTerminated.String() { + log.Fatal("failed to terminate workflow") + } + + fmt.Println("workflow terminated") + + err = daprClient.PurgeWorkflowBeta1(ctx, &client.PurgeWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + + respGet, err = daprClient.GetWorkflowBeta1(ctx, &client.GetWorkflowRequest{ + InstanceID: "a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9", + WorkflowComponent: workflowComponent, + }) + if err == nil || respGet != nil { + log.Fatalf("failed to purge workflow: %v", err) + } + + fmt.Println("workflow purged") + + // WFClient + // TODO: Expand client validation + + stage = 0 + fmt.Println("workflow client test") + + wfClient, err := workflow.NewClient() + if err != nil { + log.Fatalf("[wfclient] faield to initialize: %v", err) + } + + id, err := wfClient.ScheduleNewWorkflow(ctx, "TestWorkflow", workflow.WithInstanceID("a7a4168d-3a1c-41da-8a4f-e7f6d9c718d9"), workflow.WithInput(1)) + if err != nil { + log.Fatalf("[wfclient] failed to start workflow: %v", err) + } + + fmt.Printf("[wfclient] started workflow with id: %s\n", id) + + metadata, err := wfClient.FetchWorkflowMetadata(ctx, id) + if err != nil { + log.Fatalf("[wfclient] failed to get worfklow: %v", err) + } + + fmt.Printf("[wfclient] workflow status: %v\n", metadata.RuntimeStatus.String()) + + if stage != 1 { + log.Fatalf("Workflow assertion failed while validating the wfclient. Stage 1 expected, current: %d", stage) + } + + fmt.Printf("[wfclient] stage: %d\n", stage) + + // TODO: WaitForWorkflowStart + // TODO: WaitForWorkflowCompletion + + // raise event + + if err := wfClient.RaiseEvent(ctx, id, "testEvent", workflow.WithEventPayload("testData")); err != nil { + log.Fatalf("[wfclient] failed to raise event: %v", err) + } + + fmt.Println("[wfclient] event raised") + + // Sleep to allow the workflow to advance + time.Sleep(time.Second) + + if stage != 2 { + log.Fatalf("Workflow assertion failed while validating the wfclient. Stage 2 expected, current: %d", stage) + } + + fmt.Printf("[wfclient] stage: %d\n", stage) + + // stop workflow + if err := wfClient.TerminateWorkflow(ctx, id); err != nil { + log.Fatalf("[wfclient] failed to terminate workflow: %v", err) + } + + fmt.Println("[wfclient] workflow terminated") + + if err := wfClient.PurgeWorkflow(ctx, id); err != nil { + log.Fatalf("[wfclient] failed to purge workflow: %v", err) + } + + fmt.Println("[wfclient] workflow purged") + + // stop workflow runtime + if err := w.Shutdown(); err != nil { + log.Fatalf("failed to shutdown runtime: %v", err) + } + + fmt.Println("workflow worker successfully shutdown") +} + +func TestWorkflow(ctx *workflow.WorkflowContext) (any, error) { + var input int + if err := ctx.GetInput(&input); err != nil { + return nil, err + } + var output string + if err := ctx.CallActivity(TestActivity, workflow.ActivityInput(input)).Await(&output); err != nil { + return nil, err + } + + err := ctx.WaitForExternalEvent("testEvent", time.Second*60).Await(&output) + if err != nil { + return nil, err + } + + if err := ctx.CallActivity(TestActivity, workflow.ActivityInput(input)).Await(&output); err != nil { + return nil, err + } + + return output, nil +} + +func TestActivity(ctx workflow.ActivityContext) (any, error) { + var input int + if err := ctx.GetInput(&input); err != nil { + return "", err + } + + stage += input + + return fmt.Sprintf("Stage: %d", stage), nil +} diff --git a/go.mod b/go.mod index 61f094e5..f4da681e 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-chi/chi/v5 v5.0.11 github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 + github.com/microsoft/durabletask-go v0.4.1-0.20240122160106-fb5c4c05729d github.com/stretchr/testify v1.8.4 google.golang.org/grpc v1.61.0 google.golang.org/protobuf v1.32.0 @@ -16,11 +17,16 @@ require ( ) require ( + github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/kr/text v0.2.0 // indirect + github.com/marusama/semaphore/v2 v2.5.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect go.opentelemetry.io/otel v1.23.1 // indirect + go.opentelemetry.io/otel/metric v1.23.1 // indirect go.opentelemetry.io/otel/trace v1.23.1 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.17.0 // indirect diff --git a/go.sum b/go.sum index df4dd2d0..965484a7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= +github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/dapr/dapr v1.13.0-rc.2 h1:Y5tQ07KB856aSWXxVjb/Lob4AT8Gy/hJxZtwODI21CI= github.com/dapr/dapr v1.13.0-rc.2/go.mod h1:QvxJ5htwv17PeRfFMGkHznEVRkpnt35re7TpF4CsCc8= @@ -5,6 +7,11 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -21,6 +28,10 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM= +github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ= +github.com/microsoft/durabletask-go v0.4.1-0.20240122160106-fb5c4c05729d h1:CVjystOHucBzKExLHD8E96D4KUNbehP0ozgue/6Tq/Y= +github.com/microsoft/durabletask-go v0.4.1-0.20240122160106-fb5c4c05729d/go.mod h1:OSZ4K7SgqBEsaouk3lAVdDzvanIzsdj7angZ0FTeSAU= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= @@ -28,6 +39,8 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opentelemetry.io/otel v1.23.1 h1:Za4UzOqJYS+MUczKI320AtqZHZb7EqxO00jAHE0jmQY= go.opentelemetry.io/otel v1.23.1/go.mod h1:Td0134eafDLcTS4y+zQ26GE8u3dEuRBiBCTUIRHaikA= +go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyelTl2rlo= +go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI= go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8= go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/workflow/activity_context.go b/workflow/activity_context.go new file mode 100644 index 00000000..81c60b6a --- /dev/null +++ b/workflow/activity_context.go @@ -0,0 +1,69 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "context" + "encoding/json" + + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/microsoft/durabletask-go/task" +) + +type ActivityContext struct { + ctx task.ActivityContext +} + +func (wfac *ActivityContext) GetInput(v interface{}) error { + return wfac.ctx.GetInput(&v) +} + +func (wfac *ActivityContext) Context() context.Context { + return wfac.ctx.Context() +} + +type callActivityOption func(*callActivityOptions) error + +type callActivityOptions struct { + rawInput *wrapperspb.StringValue +} + +// ActivityInput is an option to pass a JSON-serializable input +func ActivityInput(input any) callActivityOption { + return func(opts *callActivityOptions) error { + data, err := marshalData(input) + if err != nil { + return err + } + opts.rawInput = wrapperspb.String(string(data)) + return nil + } +} + +// ActivityRawInput is an option to pass a byte slice as an input +func ActivityRawInput(input string) callActivityOption { + return func(opts *callActivityOptions) error { + opts.rawInput = wrapperspb.String(input) + return nil + } +} + +func marshalData(input any) ([]byte, error) { + if input == nil { + return nil, nil + } + return json.Marshal(input) +} diff --git a/workflow/activity_context_test.go b/workflow/activity_context_test.go new file mode 100644 index 00000000..0e73e5e7 --- /dev/null +++ b/workflow/activity_context_test.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testingTaskActivityContext struct { + inputBytes []byte +} + +func (t *testingTaskActivityContext) GetInput(v any) error { + return json.Unmarshal(t.inputBytes, &v) +} + +func (t *testingTaskActivityContext) Context() context.Context { + return context.TODO() +} + +func TestActivityContext(t *testing.T) { + inputString := "testInputString" + inputBytes, err := json.Marshal(inputString) + require.NoErrorf(t, err, "required no error, but got %v", err) + + ac := ActivityContext{ctx: &testingTaskActivityContext{inputBytes: inputBytes}} + t.Run("test getinput", func(t *testing.T) { + var inputReturn string + err := ac.GetInput(&inputReturn) + require.NoError(t, err) + assert.Equal(t, inputString, inputReturn) + }) + + t.Run("test context", func(t *testing.T) { + assert.Equal(t, context.TODO(), ac.Context()) + }) +} + +func TestCallActivityOptions(t *testing.T) { + t.Run("activity input - valid", func(t *testing.T) { + opts := returnCallActivityOptions(ActivityInput("test")) + assert.Equal(t, "\"test\"", opts.rawInput.GetValue()) + }) + + t.Run("activity input - invalid", func(t *testing.T) { + opts := returnCallActivityOptions(ActivityInput(make(chan int))) + assert.Empty(t, opts.rawInput.GetValue()) + }) + + t.Run("activity raw input - valid", func(t *testing.T) { + opts := returnCallActivityOptions(ActivityRawInput("test")) + assert.Equal(t, "test", opts.rawInput.GetValue()) + }) +} + +func returnCallActivityOptions(opts ...callActivityOption) callActivityOptions { + options := new(callActivityOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return *options + } + } + return *options +} + +func TestMarshalData(t *testing.T) { + t.Run("test nil input", func(t *testing.T) { + out, err := marshalData(nil) + require.NoError(t, err) + assert.Nil(t, out) + }) + + t.Run("test string input", func(t *testing.T) { + out, err := marshalData("testString") + require.NoError(t, err) + fmt.Println(out) + assert.Equal(t, []byte{0x22, 0x74, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x22}, out) + }) +} diff --git a/workflow/client.go b/workflow/client.go new file mode 100644 index 00000000..2a4d98ca --- /dev/null +++ b/workflow/client.go @@ -0,0 +1,204 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/microsoft/durabletask-go/api" + "github.com/microsoft/durabletask-go/backend" + durabletaskclient "github.com/microsoft/durabletask-go/client" + + dapr "github.com/dapr/go-sdk/client" +) + +type client struct { + taskHubClient *durabletaskclient.TaskHubGrpcClient +} + +// WithInstanceID is an option to set an InstanceID when scheduling a new workflow. +func WithInstanceID(id string) api.NewOrchestrationOptions { + return api.WithInstanceID(api.InstanceID(id)) +} + +// TODO: Implement WithOrchestrationIdReusePolicy + +// WithInput is an option to pass an input when scheduling a new workflow. +func WithInput(input any) api.NewOrchestrationOptions { + return api.WithInput(input) +} + +// WithRawInput is an option to pass a byte slice as an input when scheduling a new workflow. +func WithRawInput(input string) api.NewOrchestrationOptions { + return api.WithRawInput(input) +} + +// WithStartTime is an option to set the start time when scheduling a new workflow. +func WithStartTime(time time.Time) api.NewOrchestrationOptions { + return api.WithStartTime(time) +} + +// WithFetchPayloads is an option to return the payload from a workflow. +func WithFetchPayloads(fetchPayloads bool) api.FetchOrchestrationMetadataOptions { + return api.WithFetchPayloads(fetchPayloads) +} + +// WithEventPayload is an option to send a payload with an event to a workflow. +func WithEventPayload(data any) api.RaiseEventOptions { + return api.WithEventPayload(data) +} + +// WithRawEventData is an option to send a byte slice with an event to a workflow. +func WithRawEventData(data string) api.RaiseEventOptions { + return api.WithRawEventData(data) +} + +// WithOutput is an option to define an output when terminating a workflow. +func WithOutput(data any) api.TerminateOptions { + return api.WithOutput(data) +} + +// WithRawOutput is an option to define a byte slice to output when terminating a workflow. +func WithRawOutput(data string) api.TerminateOptions { + return api.WithRawOutput(data) +} + +type clientOption func(*clientOptions) error + +type clientOptions struct { + daprClient dapr.Client +} + +// WithDaprClient is an option to supply a custom dapr.Client to the workflow client. +func WithDaprClient(input dapr.Client) clientOption { + return func(opt *clientOptions) error { + opt.daprClient = input + return nil + } +} + +// TODO: Implement mocks + +// NewClient returns a workflow client. +func NewClient(opts ...clientOption) (client, error) { + options := new(clientOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return client{}, fmt.Errorf("failed to load options: %v", err) + } + } + var daprClient dapr.Client + var err error + if options.daprClient == nil { + daprClient, err = dapr.NewClient() + } else { + daprClient = options.daprClient + } + if err != nil { + return client{}, fmt.Errorf("failed to initialise dapr.Client: %v", err) + } + + taskHubClient := durabletaskclient.NewTaskHubGrpcClient(daprClient.GrpcClientConn(), backend.DefaultLogger()) + + return client{ + taskHubClient: taskHubClient, + }, nil +} + +// ScheduleNewWorkflow will start a workflow and return the ID and/or error. +func (c *client) ScheduleNewWorkflow(ctx context.Context, workflow string, opts ...api.NewOrchestrationOptions) (id string, err error) { + if workflow == "" { + return "", errors.New("no workflow specified") + } + workflowID, err := c.taskHubClient.ScheduleNewOrchestration(ctx, workflow, opts...) + return string(workflowID), err +} + +// FetchWorkflowMetadata will return the metadata for a given workflow InstanceID and/or error. +func (c *client) FetchWorkflowMetadata(ctx context.Context, id string, opts ...api.FetchOrchestrationMetadataOptions) (*Metadata, error) { + if id == "" { + return nil, errors.New("no workflow id specified") + } + wfMetadata, err := c.taskHubClient.FetchOrchestrationMetadata(ctx, api.InstanceID(id), opts...) + + return convertMetadata(wfMetadata), err +} + +// WaitForWorkflowStart will wait for a given workflow to start and return metadata and/or an error. +func (c *client) WaitForWorkflowStart(ctx context.Context, id string, opts ...api.FetchOrchestrationMetadataOptions) (*Metadata, error) { + if id == "" { + return nil, errors.New("no workflow id specified") + } + wfMetadata, err := c.taskHubClient.WaitForOrchestrationStart(ctx, api.InstanceID(id), opts...) + + return convertMetadata(wfMetadata), err +} + +// WaitForWorkflowCompletion will block pending the completion of a specified workflow and return the metadata and/or error. +func (c *client) WaitForWorkflowCompletion(ctx context.Context, id string, opts ...api.FetchOrchestrationMetadataOptions) (*Metadata, error) { + if id == "" { + return nil, errors.New("no workflow id specified") + } + wfMetadata, err := c.taskHubClient.WaitForOrchestrationCompletion(ctx, api.InstanceID(id), opts...) + + return convertMetadata(wfMetadata), err +} + +// TerminateWorkflow will stop a given workflow and return an error output. +func (c *client) TerminateWorkflow(ctx context.Context, id string, opts ...api.TerminateOptions) error { + if id == "" { + return errors.New("no workflow id specified") + } + return c.taskHubClient.TerminateOrchestration(ctx, api.InstanceID(id), opts...) +} + +// RaiseEvent will raise an event on a given workflow and return an error output. +func (c *client) RaiseEvent(ctx context.Context, id, eventName string, opts ...api.RaiseEventOptions) error { + if id == "" { + return errors.New("no workflow id specified") + } + if eventName == "" { + return errors.New("no event name specified") + } + return c.taskHubClient.RaiseEvent(ctx, api.InstanceID(id), eventName, opts...) +} + +// SuspendWorkflow will pause a given workflow and return an error output. +func (c *client) SuspendWorkflow(ctx context.Context, id, reason string) error { + if id == "" { + return errors.New("no workflow id specified") + } + return c.taskHubClient.SuspendOrchestration(ctx, api.InstanceID(id), reason) +} + +// ResumeWorkflow will resume a suspended workflow and return an error output. +func (c *client) ResumeWorkflow(ctx context.Context, id, reason string) error { + if id == "" { + return errors.New("no workflow id specified") + } + return c.taskHubClient.ResumeOrchestration(ctx, api.InstanceID(id), reason) +} + +// PurgeWorkflow will purge a given workflow and return an error output. +// NOTE: The workflow must be in a terminated or completed state. +func (c *client) PurgeWorkflow(ctx context.Context, id string) error { + if id == "" { + return errors.New("no workflow id specified") + } + return c.taskHubClient.PurgeOrchestrationState(ctx, api.InstanceID(id)) +} diff --git a/workflow/client_test.go b/workflow/client_test.go new file mode 100644 index 00000000..4e3647d8 --- /dev/null +++ b/workflow/client_test.go @@ -0,0 +1,109 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + daprClient "github.com/dapr/go-sdk/client" +) + +func TestNewClient(t *testing.T) { + // Currently will always fail if no dapr connection available + testClient, err := NewClient() + assert.Empty(t, testClient) + require.Error(t, err) +} + +func TestClientOptions(t *testing.T) { + t.Run("with client", func(t *testing.T) { + opts := returnClientOptions(WithDaprClient(&daprClient.GRPCClient{})) + assert.NotNil(t, opts.daprClient) + }) +} + +func returnClientOptions(opts ...clientOption) clientOptions { + options := new(clientOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return *options + } + } + return *options +} + +func TestClientMethods(t *testing.T) { + testClient := client{ + taskHubClient: nil, + } + ctx := context.Background() + t.Run("ScheduleNewWorkflow - empty wf name", func(t *testing.T) { + id, err := testClient.ScheduleNewWorkflow(ctx, "", nil) + require.Error(t, err) + assert.Empty(t, id) + }) + + t.Run("FetchWorkflowMetadata - empty id", func(t *testing.T) { + metadata, err := testClient.FetchWorkflowMetadata(ctx, "") + require.Error(t, err) + assert.Nil(t, metadata) + }) + + t.Run("WaitForWorkflowStart - empty id", func(t *testing.T) { + metadata, err := testClient.WaitForWorkflowStart(ctx, "") + require.Error(t, err) + assert.Nil(t, metadata) + }) + + t.Run("WaitForWorkflowCompletion - empty id", func(t *testing.T) { + metadata, err := testClient.WaitForWorkflowCompletion(ctx, "") + require.Error(t, err) + assert.Nil(t, metadata) + }) + + t.Run("TerminateWorkflow - empty id", func(t *testing.T) { + err := testClient.TerminateWorkflow(ctx, "") + require.Error(t, err) + }) + + t.Run("RaiseEvent - empty id", func(t *testing.T) { + err := testClient.RaiseEvent(ctx, "", "EventName") + require.Error(t, err) + }) + + t.Run("RaiseEvent - empty eventName", func(t *testing.T) { + err := testClient.RaiseEvent(ctx, "testID", "") + require.Error(t, err) + }) + + t.Run("SuspendWorkflow - empty id", func(t *testing.T) { + err := testClient.SuspendWorkflow(ctx, "", "reason") + require.Error(t, err) + }) + + t.Run("ResumeWorkflow - empty id", func(t *testing.T) { + err := testClient.ResumeWorkflow(ctx, "", "reason") + require.Error(t, err) + }) + + t.Run("PurgeWorkflow - empty id", func(t *testing.T) { + err := testClient.PurgeWorkflow(ctx, "") + require.Error(t, err) + }) +} diff --git a/workflow/context.go b/workflow/context.go new file mode 100644 index 00000000..7bec4f25 --- /dev/null +++ b/workflow/context.go @@ -0,0 +1,114 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "fmt" + "time" + + "github.com/microsoft/durabletask-go/task" +) + +type WorkflowContext struct { + orchestrationContext *task.OrchestrationContext +} + +// GetInput casts the input from the context to a specified interface. +func (wfc *WorkflowContext) GetInput(v interface{}) error { + return wfc.orchestrationContext.GetInput(&v) +} + +// Name returns the name string from the workflow context. +func (wfc *WorkflowContext) Name() string { + return wfc.orchestrationContext.Name +} + +// InstanceID returns the ID of the currently executing workflow +func (wfc *WorkflowContext) InstanceID() string { + return fmt.Sprintf("%v", wfc.orchestrationContext.ID) +} + +// CurrentUTCDateTime returns the current workflow time as UTC. Note that this should be used instead of `time.Now()`, which is not compatible with workflow replays. +func (wfc *WorkflowContext) CurrentUTCDateTime() time.Time { + return wfc.orchestrationContext.CurrentTimeUtc +} + +// IsReplaying returns whether the workflow is replaying. +func (wfc *WorkflowContext) IsReplaying() bool { + return wfc.orchestrationContext.IsReplaying +} + +// CallActivity returns a completable task for a given activity. +// You must call Await(output any) on the returned Task to block the workflow and wait for the task to complete. +// The value passed to the Await method must be a pointer or can be nil to ignore the returned value. +// Alternatively, tasks can be awaited using the task.WhenAll or task.WhenAny methods, allowing the workflow +// to block and wait for multiple tasks at the same time. +func (wfc *WorkflowContext) CallActivity(activity interface{}, opts ...callActivityOption) task.Task { + options := new(callActivityOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return nil + } + } + + return wfc.orchestrationContext.CallActivity(activity, task.WithRawActivityInput(options.rawInput.GetValue())) +} + +// CallChildWorkflow returns a completable task for a given workflow. +// You must call Await(output any) on the returned Task to block the workflow and wait for the task to complete. +// The value passed to the Await method must be a pointer or can be nil to ignore the returned value. +// Alternatively, tasks can be awaited using the task.WhenAll or task.WhenAny methods, allowing the workflow +// to block and wait for multiple tasks at the same time. +func (wfc *WorkflowContext) CallChildWorkflow(workflow interface{}, opts ...callChildWorkflowOption) task.Task { + options := new(callChildWorkflowOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return nil + } + } + if options.instanceID != "" { + return wfc.orchestrationContext.CallSubOrchestrator(workflow, task.WithRawSubOrchestratorInput(options.rawInput.GetValue()), task.WithSubOrchestrationInstanceID(options.instanceID)) + } + return wfc.orchestrationContext.CallSubOrchestrator(workflow, task.WithRawSubOrchestratorInput(options.rawInput.GetValue())) +} + +// CreateTimer returns a completable task that blocks for a given duration. +// You must call Await(output any) on the returned Task to block the workflow and wait for the task to complete. +// The value passed to the Await method must be a pointer or can be nil to ignore the returned value. +// Alternatively, tasks can be awaited using the task.WhenAll or task.WhenAny methods, allowing the workflow +// to block and wait for multiple tasks at the same time. +func (wfc *WorkflowContext) CreateTimer(duration time.Duration) task.Task { + return wfc.orchestrationContext.CreateTimer(duration) +} + +// WaitForExternalEvent returns a completabel task that waits for a given event to be received. +// You must call Await(output any) on the returned Task to block the workflow and wait for the task to complete. +// The value passed to the Await method must be a pointer or can be nil to ignore the returned value. +// Alternatively, tasks can be awaited using the task.WhenAll or task.WhenAny methods, allowing the workflow +// to block and wait for multiple tasks at the same time. +func (wfc *WorkflowContext) WaitForExternalEvent(eventName string, timeout time.Duration) task.Task { + if eventName == "" { + return nil + } + return wfc.orchestrationContext.WaitForSingleEvent(eventName, timeout) +} + +// ContinueAsNew configures the workflow. +func (wfc *WorkflowContext) ContinueAsNew(newInput any, keepEvents bool) { + if !keepEvents { + wfc.orchestrationContext.ContinueAsNew(newInput) + } + wfc.orchestrationContext.ContinueAsNew(newInput, task.WithKeepUnprocessedEvents()) +} diff --git a/workflow/context_test.go b/workflow/context_test.go new file mode 100644 index 00000000..1332c7b4 --- /dev/null +++ b/workflow/context_test.go @@ -0,0 +1,67 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "testing" + "time" + + "github.com/microsoft/durabletask-go/task" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestContext(t *testing.T) { + c := WorkflowContext{ + orchestrationContext: &task.OrchestrationContext{ + ID: "test-id", + Name: "test-workflow-context", + IsReplaying: false, + CurrentTimeUtc: time.Date(2023, time.December, 17, 18, 44, 0, 0, time.UTC), + }, + } + t.Run("get input - empty", func(t *testing.T) { + var input string + err := c.GetInput(&input) + require.NoError(t, err) + assert.Equal(t, "", input) + }) + t.Run("workflow name", func(t *testing.T) { + name := c.Name() + assert.Equal(t, "test-workflow-context", name) + }) + t.Run("instance id", func(t *testing.T) { + instanceID := c.InstanceID() + assert.Equal(t, "test-id", instanceID) + }) + t.Run("current utc date time", func(t *testing.T) { + date := c.CurrentUTCDateTime() + assert.Equal(t, time.Date(2023, time.December, 17, 18, 44, 0, 0, time.UTC), date) + }) + t.Run("is replaying", func(t *testing.T) { + replaying := c.IsReplaying() + assert.False(t, replaying) + }) + + t.Run("waitforexternalevent - empty ids", func(t *testing.T) { + completableTask := c.WaitForExternalEvent("", time.Second) + assert.Nil(t, completableTask) + }) + + t.Run("continueasnew", func(t *testing.T) { + c.ContinueAsNew("test", true) + c.ContinueAsNew("test", false) + }) +} diff --git a/workflow/state.go b/workflow/state.go new file mode 100644 index 00000000..969b9dad --- /dev/null +++ b/workflow/state.go @@ -0,0 +1,59 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import "github.com/microsoft/durabletask-go/api" + +type Status int + +const ( + StatusRunning Status = iota + StatusCompleted + StatusContinuedAsNew + StatusFailed + StatusCanceled + StatusTerminated + StatusPending + StatusSuspended + StatusUnknown +) + +// String returns the runtime status as a string. +func (s Status) String() string { + status := [...]string{ + "RUNNING", + "COMPLETED", + "CONTINUED_AS_NEW", + "FAILED", + "CANCELED", + "TERMINATED", + "PENDING", + "SUSPENDED", + } + if s > StatusSuspended || s < StatusRunning { + return "UNKNOWN" + } + return status[s] +} + +type WorkflowState struct { + Metadata api.OrchestrationMetadata +} + +// RuntimeStatus returns the status from a workflow state. +func (wfs *WorkflowState) RuntimeStatus() Status { + s := Status(wfs.Metadata.RuntimeStatus.Number()) + return s +} diff --git a/workflow/state_test.go b/workflow/state_test.go new file mode 100644 index 00000000..459cdc4f --- /dev/null +++ b/workflow/state_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "testing" + + "github.com/microsoft/durabletask-go/api" + "github.com/stretchr/testify/assert" +) + +func TestString(t *testing.T) { + wfState := WorkflowState{Metadata: api.OrchestrationMetadata{RuntimeStatus: 0}} + + t.Run("test running", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "RUNNING", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 1 + + t.Run("test completed", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "COMPLETED", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 2 + + t.Run("test continued_as_new", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "CONTINUED_AS_NEW", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 3 + + t.Run("test failed", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "FAILED", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 4 + + t.Run("test canceled", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "CANCELED", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 5 + + t.Run("test terminated", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "TERMINATED", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 6 + + t.Run("test pending", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "PENDING", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 7 + + t.Run("test suspended", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "SUSPENDED", s.String()) + }) + + wfState.Metadata.RuntimeStatus = 8 + + t.Run("test unknown", func(t *testing.T) { + s := wfState.RuntimeStatus() + assert.Equal(t, "UNKNOWN", s.String()) + }) +} diff --git a/workflow/worker.go b/workflow/worker.go new file mode 100644 index 00000000..94953e99 --- /dev/null +++ b/workflow/worker.go @@ -0,0 +1,164 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "context" + "errors" + "fmt" + "log" + "reflect" + "runtime" + "strings" + + dapr "github.com/dapr/go-sdk/client" + + "github.com/microsoft/durabletask-go/backend" + durabletaskclient "github.com/microsoft/durabletask-go/client" + "github.com/microsoft/durabletask-go/task" +) + +type WorkflowWorker struct { + tasks *task.TaskRegistry + client *durabletaskclient.TaskHubGrpcClient + + close func() + cancel context.CancelFunc +} + +type Workflow func(ctx *WorkflowContext) (any, error) + +type Activity func(ctx ActivityContext) (any, error) + +type workerOption func(*workerOptions) error + +type workerOptions struct { + daprClient dapr.Client +} + +// WorkerWithDaprClient allows you to specify a custom dapr.Client for the worker. +func WorkerWithDaprClient(input dapr.Client) workerOption { + return func(opts *workerOptions) error { + opts.daprClient = input + return nil + } +} + +// NewWorker returns a worker that can interface with the workflow engine +func NewWorker(opts ...workerOption) (*WorkflowWorker, error) { + options := new(workerOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return nil, errors.New("failed to load options") + } + } + var daprClient dapr.Client + var err error + if options.daprClient == nil { + daprClient, err = dapr.NewClient() + } else { + daprClient = options.daprClient + } + if err != nil { + return nil, err + } + grpcConn := daprClient.GrpcClientConn() + + return &WorkflowWorker{ + tasks: task.NewTaskRegistry(), + client: durabletaskclient.NewTaskHubGrpcClient(grpcConn, backend.DefaultLogger()), + close: daprClient.Close, + }, nil +} + +// getFunctionName returns the function name as a string +func getFunctionName(f interface{}) (string, error) { + if f == nil { + return "", errors.New("nil function name") + } + + callSplit := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), ".") + + funcName := callSplit[len(callSplit)-1] + + if funcName == "1" { + return "", errors.New("anonymous function name") + } + + return funcName, nil +} + +func wrapWorkflow(w Workflow) task.Orchestrator { + return func(ctx *task.OrchestrationContext) (any, error) { + wfCtx := &WorkflowContext{orchestrationContext: ctx} + return w(wfCtx) + } +} + +// RegisterWorkflow adds a workflow function to the registry +func (ww *WorkflowWorker) RegisterWorkflow(w Workflow) error { + wrappedOrchestration := wrapWorkflow(w) + + // get the function name for the passed workflow + name, err := getFunctionName(w) + if err != nil { + return fmt.Errorf("failed to get workflow decorator: %v", err) + } + + err = ww.tasks.AddOrchestratorN(name, wrappedOrchestration) + return err +} + +func wrapActivity(a Activity) task.Activity { + return func(ctx task.ActivityContext) (any, error) { + aCtx := ActivityContext{ctx: ctx} + + return a(aCtx) + } +} + +// RegisterActivity adds an activity function to the registry +func (ww *WorkflowWorker) RegisterActivity(a Activity) error { + wrappedActivity := wrapActivity(a) + + // get the function name for the passed activity + name, err := getFunctionName(a) + if err != nil { + return fmt.Errorf("failed to get activity decorator: %v", err) + } + + err = ww.tasks.AddActivityN(name, wrappedActivity) + return err +} + +// Start initialises a non-blocking worker to handle workflows and activities registered +// prior to this being called. +func (ww *WorkflowWorker) Start() error { + ctx, cancel := context.WithCancel(context.Background()) + ww.cancel = cancel + if err := ww.client.StartWorkItemListener(ctx, ww.tasks); err != nil { + return fmt.Errorf("failed to start work stream: %v", err) + } + log.Println("work item listener started") + return nil +} + +// Shutdown stops the worker +func (ww *WorkflowWorker) Shutdown() error { + ww.cancel() + ww.close() + log.Println("work item listener shutdown") + return nil +} diff --git a/workflow/worker_test.go b/workflow/worker_test.go new file mode 100644 index 00000000..87589607 --- /dev/null +++ b/workflow/worker_test.go @@ -0,0 +1,117 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "testing" + + daprClient "github.com/dapr/go-sdk/client" + + "github.com/microsoft/durabletask-go/task" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRuntime(t *testing.T) { + t.Run("failure to create newruntime without dapr", func(t *testing.T) { + wr, err := NewWorker() + require.Error(t, err) + assert.Empty(t, wr) + }) +} + +func TestWorkflowRuntime(t *testing.T) { + testWorker := WorkflowWorker{ + tasks: task.NewTaskRegistry(), + client: nil, + } + + // TODO: Mock grpc conn - currently requires dapr to be available + t.Run("register workflow", func(t *testing.T) { + err := testWorker.RegisterWorkflow(testWorkflow) + require.NoError(t, err) + }) + t.Run("register workflow - anonymous func", func(t *testing.T) { + err := testWorker.RegisterWorkflow(func(ctx *WorkflowContext) (any, error) { + return nil, nil + }) + require.Error(t, err) + }) + t.Run("register activity", func(t *testing.T) { + err := testWorker.RegisterActivity(testActivity) + require.NoError(t, err) + }) + t.Run("register activity - anonymous func", func(t *testing.T) { + err := testWorker.RegisterActivity(func(ctx ActivityContext) (any, error) { + return nil, nil + }) + require.Error(t, err) + }) +} + +func TestWorkerOptions(t *testing.T) { + t.Run("worker client option", func(t *testing.T) { + options := returnWorkerOptions(WorkerWithDaprClient(&daprClient.GRPCClient{})) + assert.NotNil(t, options.daprClient) + }) +} + +func returnWorkerOptions(opts ...workerOption) workerOptions { + options := new(workerOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return *options + } + } + return *options +} + +func TestWrapWorkflow(t *testing.T) { + t.Run("wrap workflow", func(t *testing.T) { + orchestrator := wrapWorkflow(testWorkflow) + assert.NotNil(t, orchestrator) + }) +} + +func TestWrapActivity(t *testing.T) { + t.Run("wrap activity", func(t *testing.T) { + activity := wrapActivity(testActivity) + assert.NotNil(t, activity) + }) +} + +func TestGetFunctionName(t *testing.T) { + t.Run("get function name", func(t *testing.T) { + name, err := getFunctionName(testWorkflow) + require.NoError(t, err) + assert.Equal(t, "testWorkflow", name) + }) + t.Run("get function name - nil", func(t *testing.T) { + name, err := getFunctionName(nil) + require.Error(t, err) + assert.Equal(t, "", name) + }) +} + +func testWorkflow(ctx *WorkflowContext) (any, error) { + _ = ctx + return nil, nil +} + +func testActivity(ctx ActivityContext) (any, error) { + _ = ctx + return nil, nil +} diff --git a/workflow/workflow.go b/workflow/workflow.go new file mode 100644 index 00000000..0fc9fca6 --- /dev/null +++ b/workflow/workflow.go @@ -0,0 +1,121 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package workflow + +import ( + "fmt" + "time" + + "github.com/microsoft/durabletask-go/api" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type Metadata struct { + InstanceID string `json:"id"` + Name string `json:"name"` + RuntimeStatus Status `json:"status"` + CreatedAt time.Time `json:"createdAt"` + LastUpdatedAt time.Time `json:"lastUpdatedAt"` + SerializedInput string `json:"serializedInput"` + SerializedOutput string `json:"serializedOutput"` + SerializedCustomStatus string `json:"serializedCustomStatus"` + FailureDetails *FailureDetails `json:"failureDetails"` +} + +type FailureDetails struct { + Type string `json:"type"` + Message string `json:"message"` + StackTrace string `json:"stackTrace"` + InnerFailure *FailureDetails `json:"innerFailure"` + IsNonRetriable bool `json:"IsNonRetriable"` +} + +func convertMetadata(orchestrationMetadata *api.OrchestrationMetadata) *Metadata { + metadata := Metadata{ + InstanceID: string(orchestrationMetadata.InstanceID), + Name: orchestrationMetadata.Name, + RuntimeStatus: Status(orchestrationMetadata.RuntimeStatus.Number()), + CreatedAt: orchestrationMetadata.CreatedAt, + LastUpdatedAt: orchestrationMetadata.LastUpdatedAt, + SerializedInput: orchestrationMetadata.SerializedInput, + SerializedOutput: orchestrationMetadata.SerializedOutput, + SerializedCustomStatus: orchestrationMetadata.SerializedCustomStatus, + } + if orchestrationMetadata.FailureDetails != nil { + metadata.FailureDetails = &FailureDetails{ + Type: orchestrationMetadata.FailureDetails.GetErrorType(), + Message: orchestrationMetadata.FailureDetails.GetErrorMessage(), + StackTrace: orchestrationMetadata.FailureDetails.GetStackTrace().GetValue(), + IsNonRetriable: orchestrationMetadata.FailureDetails.GetIsNonRetriable(), + } + + if orchestrationMetadata.FailureDetails.GetInnerFailure() != nil { + var root *FailureDetails + current := root + failure := orchestrationMetadata.FailureDetails.GetInnerFailure() + for { + current.Type = failure.GetErrorType() + current.Message = failure.GetErrorMessage() + if failure.GetStackTrace() != nil { + current.StackTrace = failure.GetStackTrace().GetValue() + } + if failure.GetInnerFailure() == nil { + break + } + failure = failure.GetInnerFailure() + var inner *FailureDetails + current.InnerFailure = inner + current = inner + } + metadata.FailureDetails.InnerFailure = root + } + } + return &metadata +} + +type callChildWorkflowOptions struct { + instanceID string + rawInput *wrapperspb.StringValue +} + +type callChildWorkflowOption func(*callChildWorkflowOptions) error + +// ChildWorkflowInput is an option to provide a JSON-serializable input when calling a child workflow. +func ChildWorkflowInput(input any) callChildWorkflowOption { + return func(opts *callChildWorkflowOptions) error { + bytes, err := marshalData(input) + if err != nil { + return fmt.Errorf("failed to marshal input data to JSON: %v", err) + } + opts.rawInput = wrapperspb.String(string(bytes)) + return nil + } +} + +// ChildWorkflowRawInput is an option to provide a byte slice input when calling a child workflow. +func ChildWorkflowRawInput(input string) callChildWorkflowOption { + return func(opts *callChildWorkflowOptions) error { + opts.rawInput = wrapperspb.String(input) + return nil + } +} + +// ChildWorkflowInstanceID is an option to provide an instance id when calling a child workflow. +func ChildWorkflowInstanceID(instanceID string) callChildWorkflowOption { + return func(opts *callChildWorkflowOptions) error { + opts.instanceID = instanceID + return nil + } +} diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go new file mode 100644 index 00000000..53354bf3 --- /dev/null +++ b/workflow/workflow_test.go @@ -0,0 +1,50 @@ +package workflow + +import ( + "testing" + + "github.com/microsoft/durabletask-go/api" + "github.com/stretchr/testify/assert" +) + +func TestConvertMetadata(t *testing.T) { + t.Run("convert metadata", func(t *testing.T) { + rawMetadata := &api.OrchestrationMetadata{ + InstanceID: api.InstanceID("test"), + } + metadata := convertMetadata(rawMetadata) + assert.NotEmpty(t, metadata) + }) +} + +func TestCallChildWorkflowOptions(t *testing.T) { + t.Run("child workflow input - valid", func(t *testing.T) { + opts := returnCallChildWorkflowOptions(ChildWorkflowInput("test")) + assert.Equal(t, "\"test\"", opts.rawInput.GetValue()) + }) + + t.Run("child workflow raw input - valid", func(t *testing.T) { + opts := returnCallChildWorkflowOptions(ChildWorkflowRawInput("test")) + assert.Equal(t, "test", opts.rawInput.GetValue()) + }) + + t.Run("child workflow instance id - valid", func(t *testing.T) { + opts := returnCallChildWorkflowOptions(ChildWorkflowInstanceID("test")) + assert.Equal(t, "test", opts.instanceID) + }) + + t.Run("child workflow input - invalid", func(t *testing.T) { + opts := returnCallChildWorkflowOptions(ChildWorkflowInput(make(chan int))) + assert.Empty(t, opts.rawInput.GetValue()) + }) +} + +func returnCallChildWorkflowOptions(opts ...callChildWorkflowOption) callChildWorkflowOptions { + options := new(callChildWorkflowOptions) + for _, configure := range opts { + if err := configure(options); err != nil { + return *options + } + } + return *options +}