From c738dbe775f9c32f24beb0dc8ba1cd1e083a7c99 Mon Sep 17 00:00:00 2001 From: mikeee Date: Tue, 5 Nov 2024 18:28:24 +0000 Subject: [PATCH] feat: conversation api implementation Signed-off-by: mikeee --- client/client.go | 3 + client/conversation.go | 132 ++++++++++++++++++ .../config/conversation-echo.yaml | 7 + examples/conversation/main.go | 36 +++++ 4 files changed, 178 insertions(+) create mode 100644 client/conversation.go create mode 100644 examples/conversation/config/conversation-echo.yaml create mode 100644 examples/conversation/main.go diff --git a/client/client.go b/client/client.go index 950a4626..998089db 100644 --- a/client/client.go +++ b/client/client.go @@ -259,6 +259,9 @@ type Client interface { // DeleteJobAlpha1 deletes a scheduled job. DeleteJobAlpha1(ctx context.Context, name string) error + // ConverseAlpha1 interacts with a conversational AI model. + ConverseAlpha1(ctx context.Context, componentName string, inputs []ConversationInput, options ...conversationRequestOption) (*ConversationResponse, error) + // GrpcClient returns the base grpc client if grpc is used and nil otherwise GrpcClient() pb.DaprClient diff --git a/client/conversation.go b/client/conversation.go new file mode 100644 index 00000000..0118b079 --- /dev/null +++ b/client/conversation.go @@ -0,0 +1,132 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "fmt" + runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1" + "google.golang.org/protobuf/types/known/anypb" +) + +type conversationRequestOptions struct { + Parameters map[string]*anypb.Any + Metadata map[string]string + ContextID *string + ScrubPII *bool // Scrub PII from the output + Temperature *float64 +} + +type conversationRequestOption func(request *conversationRequestOptions) + +type ConversationInput struct { + Message string + Role *string + ScrubPII *bool // Scrub PII from the input +} + +type ConversationInputOption func(*ConversationInput) + +func NewConversationInput(message string, opts ...ConversationInputOption) ConversationInput { + return ConversationInput{} +} + +type ConversationResponse struct { + ContextID string + Outputs []ConversationResult +} + +type ConversationResult struct { + Result string + Parameters map[string]*anypb.Any +} + +func WithParameters(parameters map[string]*anypb.Any) conversationRequestOption { + return func(o *conversationRequestOptions) { + o.Parameters = parameters + } +} + +func WithMetadata(metadata map[string]string) conversationRequestOption { + return func(o *conversationRequestOptions) { + o.Metadata = metadata + } +} + +func WithContextID(id string) conversationRequestOption { + return func(o *conversationRequestOptions) { + o.ContextID = &id + } +} + +func WithScrubPII(scrub bool) conversationRequestOption { + return func(o *conversationRequestOptions) { + o.ScrubPII = &scrub + } +} + +func WithTemperature(temp float64) conversationRequestOption { + return func(o *conversationRequestOptions) { + o.Temperature = &temp + } +} + +func (c *GRPCClient) ConverseAlpha1(ctx context.Context, componentName string, inputs []ConversationInput, options ...conversationRequestOption) (*ConversationResponse, error) { + + var cinputs []*runtimev1pb.ConversationInput + for _, i := range inputs { + cinputs = append(cinputs, &runtimev1pb.ConversationInput{ + Message: i.Message, + Role: i.Role, + ScrubPII: i.ScrubPII, + }) + } + + var o conversationRequestOptions + for _, opt := range options { + if opt != nil { + opt(&o) + } + } + + request := runtimev1pb.ConversationRequest{ + Name: componentName, + ContextID: o.ContextID, + Inputs: cinputs, + Parameters: o.Parameters, + Metadata: o.Metadata, + ScrubPII: o.ScrubPII, + Temperature: o.Temperature, + } + + fmt.Println("invoking") + + resp, err := c.protoClient.ConverseAlpha1(ctx, &request) + if err != nil { + return nil, err + } + + var outputs []ConversationResult + for _, i := range resp.GetOutputs() { + outputs = append(outputs, ConversationResult{ + Result: i.GetResult(), + Parameters: i.GetParameters(), + }) + } + + return &ConversationResponse{ + ContextID: resp.GetContextID(), + Outputs: outputs, + }, nil +} diff --git a/examples/conversation/config/conversation-echo.yaml b/examples/conversation/config/conversation-echo.yaml new file mode 100644 index 00000000..9a8b3072 --- /dev/null +++ b/examples/conversation/config/conversation-echo.yaml @@ -0,0 +1,7 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: echo +spec: + type: conversation.echo + version: v1 \ No newline at end of file diff --git a/examples/conversation/main.go b/examples/conversation/main.go new file mode 100644 index 00000000..70f1fd15 --- /dev/null +++ b/examples/conversation/main.go @@ -0,0 +1,36 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "context" + "fmt" + dapr "github.com/dapr/go-sdk/client" + "log" +) + +func main() { + client, err := dapr.NewClientWithPort("47649") + if err != nil { + panic(err) + } + + resp, err := client.ConverseAlpha1(context.Background(), "echo", []dapr.ConversationInput{{Message: "hello"}}) + if err != nil { + log.Fatalf("err: %v", err) + } + + fmt.Println(resp.Outputs) +}