
Commit 8309cff

Add bedrock mistral llm support
1 parent 5a473bc commit 8309cff

File tree: 6 files changed, +221 -2 lines
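In brief, this commit wires a "mistral" provider into the Bedrock LLM adapter in model/llm/bedrock.go: a stop-sequence key, <s>[INST] ... [/INST] prompt formatting, response parsing for both one-shot and streaming calls, and a NewBedrockMistral constructor with defaults. It also raises the default max_tokens_to_sample (the Anthropic-style parameter) from 256 to 1024 and adds four example programs, two for NewBedrockAmazon and two for NewBedrockMistral, shown first below.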
Lines changed: 32 additions & 0 deletions (new file)

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockAmazon(client, func(o *llm.BedrockAmazonOptions) {
		o.Temperature = 0.3
	})
	if err != nil {
		log.Fatal(err)
	}

	res, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Hello ai!"))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.Generations[0].Text)
}
Lines changed: 31 additions & 0 deletions (new file)

package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/callback"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
	"github.com/hupe1980/golc/schema"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockAmazon(client, func(o *llm.BedrockAmazonOptions) {
		o.Callbacks = []schema.Callback{callback.NewStreamWriterHandler()}
		o.Stream = true
	})
	if err != nil {
		log.Fatal(err)
	}

	if _, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Write me a song about sparkling water.")); err != nil {
		log.Fatal(err)
	}
}
Lines changed: 30 additions & 0 deletions (new file)

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockMistral(client)
	if err != nil {
		log.Fatal(err)
	}

	res, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Tell me a joke"))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.Generations[0].Text)
}
Lines changed: 31 additions & 0 deletions (new file)

package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/callback"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
	"github.com/hupe1980/golc/schema"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockMistral(client, func(o *llm.BedrockMistralOptions) {
		o.Callbacks = []schema.Callback{callback.NewStreamWriterHandler()}
		o.Stream = true
	})
	if err != nil {
		log.Fatal(err)
	}

	if _, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Write me a song about sparkling water.")); err != nil {
		log.Fatal(err)
	}
}
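Both streaming examples register callback.NewStreamWriterHandler() and set o.Stream = true; the handler presumably writes tokens out as they arrive, which is why the return value of model.GeneratePrompt is discarded.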

model/llm/bedrock.go

Lines changed: 96 additions & 1 deletion
@@ -30,6 +30,7 @@ var providerStopSequenceKeyMap = map[string]string{
 	"amazon":  "stopSequences",
 	"ai21":    "stop_sequences",
 	"cohere":  "stop_sequences",
+	"mistral": "stop",
 }

 // BedrockInputOutputAdapter is a helper struct for preparing input and handling output for Bedrock model.
@@ -60,7 +61,7 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
 		body = modelParams

 		if _, ok := body["max_tokens_to_sample"]; !ok {
-			body["max_tokens_to_sample"] = 256
+			body["max_tokens_to_sample"] = 1024
 		}

 		body["prompt"] = fmt.Sprintf("\n\nHuman:%s\n\nAssistant:", prompt)
@@ -70,6 +71,9 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
 	case "meta":
 		body = modelParams
 		body["prompt"] = prompt
+	case "mistral":
+		body = modelParams
+		body["prompt"] = fmt.Sprintf("<s>[INST] %s [/INST]", prompt)
 	default:
 		return nil, fmt.Errorf("unsupported provider: %s", bioa.provider)
 	}
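For illustration: with the defaults that NewBedrockMistral sets further down in this diff, a prompt such as "Tell me a joke" should yield a request body along these lines (a sketch; Go's json.Marshal emits map keys in sorted order, though exact float formatting may differ):

{"max_tokens":512,"prompt":"<s>[INST] Tell me a joke [/INST]","temperature":0.5,"top_k":200,"top_p":0.9}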
@@ -115,6 +119,14 @@ type metaOutput struct {
 	Generation string `json:"generation"`
 }

+// mistralOutput is a struct representing the output structure for the "mistral" provider.
+type mistralOutput struct {
+	Outputs []struct {
+		Text       string `json:"text"`
+		StopReason string `json:"stop_reason"`
+	} `json:"outputs"`
+}
+
 // PrepareOutput prepares the output for the Bedrock model based on the specified provider.
 func (bioa *BedrockInputOutputAdapter) PrepareOutput(response []byte) (string, error) {
 	switch bioa.provider {
@@ -153,6 +165,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareOutput(response []byte) (string, e
 		}

 		return output.Generation, nil
+	case "mistral":
+		output := &mistralOutput{}
+		if err := json.Unmarshal(response, output); err != nil {
+			return "", err
+		}
+
+		return output.Outputs[0].Text, nil
 	}

 	return "", fmt.Errorf("unsupported provider: %s", bioa.provider)
@@ -185,6 +204,14 @@ type metaStreamOutput struct {
 	Generation string `json:"generation"`
 }

+// mistralStreamOutput is a struct representing the stream output structure for the "mistral" provider.
+type mistralStreamOutput struct {
+	Outputs []struct {
+		Text       string `json:"text"`
+		StopReason string `json:"stop_reason"`
+	} `json:"outputs"`
+}
+
 // PrepareStreamOutput prepares the output for the Bedrock model based on the specified provider.
 func (bioa *BedrockInputOutputAdapter) PrepareStreamOutput(response []byte) (string, error) {
 	switch bioa.provider {
@@ -217,6 +244,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareStreamOutput(response []byte) (str
 		}

 		return output.Generation, nil
+	case "mistral":
+		output := &mistralStreamOutput{}
+		if err := json.Unmarshal(response, output); err != nil {
+			return "", err
+		}
+
+		return output.Outputs[0].Text, nil
 	}

 	return "", fmt.Errorf("unsupported provider: %s", bioa.provider)
@@ -549,6 +583,67 @@ func NewBedrockMeta(client BedrockRuntimeClient, optFns ...func(o *BedrockMetaOp
 	})
 }

+type BedrockMistralOptions struct {
+	*schema.CallbackOptions `map:"-"`
+	schema.Tokenizer        `map:"-"`
+
+	// Model id to use.
+	ModelID string `map:"model_id,omitempty"`
+
+	// Temperature controls the randomness of text generation. Higher values make it more random.
+	Temperature float32 `map:"temperature"`
+
+	// TopP is the total probability mass of tokens to consider at each step.
+	TopP float32 `map:"top_p,omitempty"`
+
+	// TopK determines how the model selects tokens for output.
+	TopK int `map:"top_k"`
+
+	// MaxTokens sets the maximum number of tokens in the generated text.
+	MaxTokens int `map:"max_tokens,omitempty"`
+
+	// Stream indicates whether to stream the results or not.
+	Stream bool `map:"stream,omitempty"`
+}
+
+func NewBedrockMistral(client BedrockRuntimeClient, optFns ...func(o *BedrockMistralOptions)) (*Bedrock, error) {
+	opts := BedrockMistralOptions{
+		CallbackOptions: &schema.CallbackOptions{
+			Verbose: golc.Verbose,
+		},
+		ModelID:     "mistral.mistral-7b-instruct-v0:2", // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
+		Temperature: 0.5,
+		TopP:        0.9,
+		TopK:        200,
+		MaxTokens:   512,
+	}
+
+	for _, fn := range optFns {
+		fn(&opts)
+	}
+
+	if opts.Tokenizer == nil {
+		var tErr error
+
+		opts.Tokenizer, tErr = tokenizer.NewGPT2()
+		if tErr != nil {
+			return nil, tErr
+		}
+	}
+
+	return NewBedrock(client, opts.ModelID, func(o *BedrockOptions) {
+		o.CallbackOptions = opts.CallbackOptions
+		o.Tokenizer = opts.Tokenizer
+		o.ModelParams = map[string]any{
+			"temperature": opts.Temperature,
+			"top_p":       opts.TopP,
+			"top_k":       opts.TopK,
+			"max_tokens":  opts.MaxTokens,
+		}
+		o.Stream = opts.Stream
+	})
+}
+
 // BedrockOptions contains options for configuring the Bedrock LLM model.
 type BedrockOptions struct {
 	*schema.CallbackOptions `map:"-"`
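As a usage sketch of the new constructor, assuming the client setup from the examples above (the alternative model id is an assumption and should be checked against the Bedrock model-id page linked in the code):

bedrock, err := llm.NewBedrockMistral(client, func(o *llm.BedrockMistralOptions) {
	o.ModelID = "mistral.mixtral-8x7b-instruct-v0:1" // assumed Mixtral model id; verify availability in your region
	o.Temperature = 0.2
	o.MaxTokens = 1024
})
if err != nil {
	log.Fatal(err)
}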

model/llm/bedrock_test.go

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ func TestBedrockInputOutputAdapter(t *testing.T) {
 			modelParams: map[string]interface{}{
 				"param1": "value1",
 			},
-			expectedBody: `{"param1":"value1","max_tokens_to_sample":256,"prompt":"\n\nHuman:Test prompt\n\nAssistant:"}`,
+			expectedBody: `{"param1":"value1","max_tokens_to_sample":1024,"prompt":"\n\nHuman:Test prompt\n\nAssistant:"}`,
 			expectedErr: "",
 		},
 		{
