Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the conversation API #646

Merged
merged 12 commits into from
Nov 27, 2024
5 changes: 3 additions & 2 deletions .github/workflows/validate_examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ jobs:
GOARCH: amd64
GOPROXY: https://proxy.golang.org
DAPR_INSTALL_URL: https://raw.githubusercontent.com/dapr/cli/master/install/install.sh
DAPR_CLI_REF: ${{ github.event.inputs.daprcli_commit }}
DAPR_REF: ${{ github.event.inputs.daprdapr_commit }}
DAPR_CLI_REF: 8bf3a1605f7b2ecfa7d4633ce4c5de13cdb65c5e
DAPR_REF: c86a77f6db5fb9f294f39d096ff0d9a053e55982
CHECKOUT_REPO: ${{ github.repository }}
CHECKOUT_REF: ${{ github.ref }}
outputs:
Expand Down Expand Up @@ -164,6 +164,7 @@ jobs:
[
"actor",
"configuration",
"conversation",
"crypto",
"dist-scheduler",
"grpc-service",
Expand Down
3 changes: 3 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ type Client interface {
// DeleteJobAlpha1 deletes a scheduled job.
DeleteJobAlpha1(ctx context.Context, name string) error

// ConverseAlpha1 interacts with a conversational AI model.
ConverseAlpha1(ctx context.Context, request conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error)

// GrpcClient returns the base grpc client if grpc is used and nil otherwise
GrpcClient() pb.DaprClient

Expand Down
146 changes: 146 additions & 0 deletions client/conversation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"

"google.golang.org/protobuf/types/known/anypb"

runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1"
)

// conversationRequest object - currently unexported as used in a functions option pattern
type conversationRequest struct {
name string
inputs []ConversationInput
Parameters map[string]*anypb.Any
Metadata map[string]string
ContextID *string
ScrubPII *bool // Scrub PII from the output
Temperature *float64
}

// NewConversationRequest defines a request with a component name and one or more inputs as a slice
func NewConversationRequest(llmName string, inputs []ConversationInput) conversationRequest {
return conversationRequest{
name: llmName,
inputs: inputs,
}

Check warning on line 40 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L36-L40

Added lines #L36 - L40 were not covered by tests
}

type conversationRequestOption func(request *conversationRequest)

// ConversationInput defines a single input.
type ConversationInput struct {
// The string to send to the llm.
Message string
// The role of the message.
Role *string
// Whether to Scrub PII from the input
ScrubPII *bool
}

// ConversationResponse is the basic response from a conversationRequest.
type ConversationResponse struct {
ContextID string
Outputs []ConversationResult
}

// ConversationResult is the individual
type ConversationResult struct {
Result string
Parameters map[string]*anypb.Any
}

// WithParameters should be used to provide parameters for custom fields.
func WithParameters(parameters map[string]*anypb.Any) conversationRequestOption {
return func(o *conversationRequest) {
o.Parameters = parameters
}

Check warning on line 71 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L68-L71

Added lines #L68 - L71 were not covered by tests
}

// WithMetadata used to define metadata to be passed to components.
func WithMetadata(metadata map[string]string) conversationRequestOption {
return func(o *conversationRequest) {
o.Metadata = metadata
}

Check warning on line 78 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L75-L78

Added lines #L75 - L78 were not covered by tests
}

// WithContextID to provide a new context or continue an existing one.
func WithContextID(id string) conversationRequestOption {
return func(o *conversationRequest) {
o.ContextID = &id
}

Check warning on line 85 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L82-L85

Added lines #L82 - L85 were not covered by tests
}

// WithScrubPII to define whether the outputs should have PII removed.
func WithScrubPII(scrub bool) conversationRequestOption {
return func(o *conversationRequest) {
o.ScrubPII = &scrub
}

Check warning on line 92 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L89-L92

Added lines #L89 - L92 were not covered by tests
}

// WithTemperature to specify which way the LLM leans.
func WithTemperature(temp float64) conversationRequestOption {
return func(o *conversationRequest) {
o.Temperature = &temp
}

Check warning on line 99 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L96-L99

Added lines #L96 - L99 were not covered by tests
}

// ConverseAlpha1 can invoke an LLM given a request created by the NewConversationRequest function.
func (c *GRPCClient) ConverseAlpha1(ctx context.Context, req conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error) {
cinputs := make([]*runtimev1pb.ConversationInput, len(req.inputs))
for i, in := range req.inputs {
cinputs[i] = &runtimev1pb.ConversationInput{
Message: in.Message,
Role: in.Role,
ScrubPII: in.ScrubPII,
}
}

Check warning on line 111 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L103-L111

Added lines #L103 - L111 were not covered by tests

for _, opt := range options {
if opt != nil {
opt(&req)
}

Check warning on line 116 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L113-L116

Added lines #L113 - L116 were not covered by tests
}

request := runtimev1pb.ConversationRequest{
Name: req.name,
ContextID: req.ContextID,
Inputs: cinputs,
Parameters: req.Parameters,
Metadata: req.Metadata,
ScrubPII: req.ScrubPII,
Temperature: req.Temperature,
}

resp, err := c.protoClient.ConverseAlpha1(ctx, &request)
if err != nil {
return nil, err
}

Check warning on line 132 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L119-L132

Added lines #L119 - L132 were not covered by tests

outputs := make([]ConversationResult, len(resp.GetOutputs()))
for i, o := range resp.GetOutputs() {
outputs[i] = ConversationResult{
Result: o.GetResult(),
Parameters: o.GetParameters(),
}
}

Check warning on line 140 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L134-L140

Added lines #L134 - L140 were not covered by tests

return &ConversationResponse{
ContextID: resp.GetContextID(),
Outputs: outputs,
}, nil

Check warning on line 145 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L142-L145

Added lines #L142 - L145 were not covered by tests
}
36 changes: 36 additions & 0 deletions examples/conversation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Dapr Conversation Example with go-sdk

## Step

### Prepare

- Dapr installed

### Run Conversation Example

<!-- STEP
name: Run Conversation
output_match_mode: substring
expected_stdout_lines:
- '== APP == conversation output: hello world'

background: true
sleep: 60
timeout_seconds: 60
-->

```bash
dapr run --app-id conversation \
--dapr-grpc-port 50001 \
--log-level debug \
--resources-path ./config \
-- go run ./main.go
```

<!-- END_STEP -->

## Result

```
- '== APP == conversation output: hello world'
```
7 changes: 7 additions & 0 deletions examples/conversation/config/conversation-echo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: echo
spec:
type: conversation.echo
version: v1
48 changes: 48 additions & 0 deletions examples/conversation/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main

import (
"context"
"fmt"
dapr "github.com/dapr/go-sdk/client"
"log"
)

func main() {
client, err := dapr.NewClient()
if err != nil {
panic(err)
}

input := dapr.ConversationInput{
Message: "hello world",
// Role: nil, // Optional
// ScrubPII: nil, // Optional
}

fmt.Printf("conversation input: %s\n", input.Message)

var conversationComponent = "echo"

request := dapr.NewConversationRequest(conversationComponent, []dapr.ConversationInput{input})

resp, err := client.ConverseAlpha1(context.Background(), request)
if err != nil {
log.Fatalf("err: %v", err)
}

fmt.Printf("conversation output: %s\n", resp.Outputs[0].Result)
}
18 changes: 9 additions & 9 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/dapr/go-sdk v0.0.0-00010101000000-000000000000
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809
google.golang.org/protobuf v1.34.2
)
Expand All @@ -18,7 +18,7 @@ require (
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dapr/dapr v1.14.1 // indirect
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-chi/chi/v5 v5.1.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
Expand All @@ -28,12 +28,12 @@ require (
github.com/marusama/semaphore/v2 v2.5.0 // indirect
github.com/microsoft/durabletask-go v0.5.1-0.20241024170039-0c4afbc95428 // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
go.opentelemetry.io/otel v1.27.0 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
40 changes: 20 additions & 20 deletions examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dapr/dapr v1.14.1 h1:n+FGF82caTsBjmnmKdBfrO94GRuLeuYs6qrAN5oG4ZM=
github.com/dapr/dapr v1.14.1/go.mod h1:oDNgaPHQIDZ3G4n4g89TElXWgkluYwcar41DI/oF4gw=
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f h1:wXPHK2o5FIABU5BvKk/21MN6GKaoUvWc7fESH/hwVls=
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f/go.mod h1:WlsLcudco11+BhaIvg2XyGxD+2GcZf8OTOawd94dAQs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -59,24 +59,24 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
go.opentelemetry.io/otel v1.27.0 h1:9BZoF3yMK/O1AafMiQTVu0YDj5Ea4hPhxCs7sGva+cg=
go.opentelemetry.io/otel v1.27.0/go.mod h1:DMpAK8fzYRzs+bi3rS5REupisuqTheUlSZJ1WnZaPAQ=
go.opentelemetry.io/otel/metric v1.27.0 h1:hvj3vdEKyeCi4YaYfNjv2NUje8FqKqUY8IlF0FxV/ik=
go.opentelemetry.io/otel/metric v1.27.0/go.mod h1:mVFgmRlhljgBiuk/MP/oKylr4hs85GZAylncepAX/ak=
go.opentelemetry.io/otel/trace v1.27.0 h1:IqYb813p7cmbHk0a5y6pD5JPakbVfftRXABGt5/Rscw=
go.opentelemetry.io/otel/trace v1.27.0/go.mod h1:6RiD1hkAprV4/q+yd2ln1HG9GoPx39SuvvstaLBl+l4=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 h1:N9BgCIAUvn/M+p4NJccWPWb3BWh88+zyL0ll9HgbEeM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw=
google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809 h1:f96Rv5C5Y2CWlbKK6KhKDdyFgGOjPHPEMsdyaxE9k0c=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809/go.mod h1:uaPEAc5V00jjG3DPhGFLXGT290RUV3+aNQigs1W50/8=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
Expand Down
18 changes: 9 additions & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ module github.com/dapr/go-sdk
go 1.23.3

require (
github.com/dapr/dapr v1.14.1
github.com/dapr/dapr v1.14.5-0.20241120233620-c86a77f6db5f
github.com/go-chi/chi/v5 v5.1.0
github.com/golang/mock v1.6.0
github.com/google/uuid v1.6.0
github.com/microsoft/durabletask-go v0.5.1-0.20241024170039-0c4afbc95428
github.com/stretchr/testify v1.9.0
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v3 v3.0.1
)
Expand All @@ -23,12 +23,12 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/marusama/semaphore/v2 v2.5.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
go.opentelemetry.io/otel v1.27.0 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
Loading