@@ -30,6 +30,7 @@ var providerStopSequenceKeyMap = map[string]string{
3030 "amazon" : "stopSequences" ,
3131 "ai21" : "stop_sequences" ,
3232 "cohere" : "stop_sequences" ,
33+ "mistral" : "stop" ,
3334}
3435
3536// BedrockInputOutputAdapter is a helper struct for preparing input and handling output for Bedrock model.
@@ -60,7 +61,7 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
6061 body = modelParams
6162
6263 if _ , ok := body ["max_tokens_to_sample" ]; ! ok {
63- body ["max_tokens_to_sample" ] = 256
64+ body ["max_tokens_to_sample" ] = 1024
6465 }
6566
6667 body ["prompt" ] = fmt .Sprintf ("\n \n Human:%s\n \n Assistant:" , prompt )
@@ -70,6 +71,9 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
7071 case "meta" :
7172 body = modelParams
7273 body ["prompt" ] = prompt
74+ case "mistral" :
75+ body = modelParams
76+ body ["prompt" ] = fmt .Sprintf ("<s>[INST] %s [/INST]" , prompt )
7377 default :
7478 return nil , fmt .Errorf ("unsupported provider: %s" , bioa .provider )
7579 }
@@ -115,6 +119,14 @@ type metaOutput struct {
115119 Generation string `json:"generation"`
116120}
117121
122+ // mistralOutput is a struct representing the output structure for the "mistral" provider.
123+ type mistralOutput struct {
124+ Outputs []struct {
125+ Text string `json:"text"`
126+ StopReason string `json:"stop_reason"`
127+ } `json:"outputs"`
128+ }
129+
118130// PrepareOutput prepares the output for the Bedrock model based on the specified provider.
119131func (bioa * BedrockInputOutputAdapter ) PrepareOutput (response []byte ) (string , error ) {
120132 switch bioa .provider {
@@ -153,6 +165,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareOutput(response []byte) (string, e
153165 }
154166
155167 return output .Generation , nil
168+ case "mistral" :
169+ output := & mistralOutput {}
170+ if err := json .Unmarshal (response , output ); err != nil {
171+ return "" , err
172+ }
173+
174+ return output .Outputs [0 ].Text , nil
156175 }
157176
158177 return "" , fmt .Errorf ("unsupported provider: %s" , bioa .provider )
@@ -185,6 +204,14 @@ type metaStreamOutput struct {
185204 Generation string `json:"generation"`
186205}
187206
207+ // mistralStreamOutput is a struct representing the stream output structure for the "mistral" provider.
208+ type mistralStreamOutput struct {
209+ Outputs []struct {
210+ Text string `json:"text"`
211+ StopReason string `json:"stop_reason"`
212+ } `json:"outputs"`
213+ }
214+
188215// PrepareStreamOutput prepares the output for the Bedrock model based on the specified provider.
189216func (bioa * BedrockInputOutputAdapter ) PrepareStreamOutput (response []byte ) (string , error ) {
190217 switch bioa .provider {
@@ -217,6 +244,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareStreamOutput(response []byte) (str
217244 }
218245
219246 return output .Generation , nil
247+ case "mistral" :
248+ output := & mistralStreamOutput {}
249+ if err := json .Unmarshal (response , output ); err != nil {
250+ return "" , err
251+ }
252+
253+ return output .Outputs [0 ].Text , nil
220254 }
221255
222256 return "" , fmt .Errorf ("unsupported provider: %s" , bioa .provider )
@@ -549,6 +583,67 @@ func NewBedrockMeta(client BedrockRuntimeClient, optFns ...func(o *BedrockMetaOp
549583 })
550584}
551585
586+ type BedrockMistralOptions struct {
587+ * schema.CallbackOptions `map:"-"`
588+ schema.Tokenizer `map:"-"`
589+
590+ // Model id to use.
591+ ModelID string `map:"model_id,omitempty"`
592+
593+ // Temperature controls the randomness of text generation. Higher values make it more random.
594+ Temperature float32 `map:"temperature"`
595+
596+ // TopP is the total probability mass of tokens to consider at each step.
597+ TopP float32 `map:"top_p,omitempty"`
598+
599+ // TopK determines how the model selects tokens for output.
600+ TopK int `map:"top_k"`
601+
602+ // MaxTokens sets the maximum number of tokens in the generated text.
603+ MaxTokens int `json:"max_tokens,omitempty"`
604+
605+ // Stream indicates whether to stream the results or not.
606+ Stream bool `map:"stream,omitempty"`
607+ }
608+
609+ func NewBedrockMistral (client BedrockRuntimeClient , optFns ... func (o * BedrockMistralOptions )) (* Bedrock , error ) {
610+ opts := BedrockMistralOptions {
611+ CallbackOptions : & schema.CallbackOptions {
612+ Verbose : golc .Verbose ,
613+ },
614+ ModelID : "mistral.mistral-7b-instruct-v0:2" , //https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
615+ Temperature : 0.5 ,
616+ TopP : 0.9 ,
617+ TopK : 200 ,
618+ MaxTokens : 512 ,
619+ }
620+
621+ for _ , fn := range optFns {
622+ fn (& opts )
623+ }
624+
625+ if opts .Tokenizer == nil {
626+ var tErr error
627+
628+ opts .Tokenizer , tErr = tokenizer .NewGPT2 ()
629+ if tErr != nil {
630+ return nil , tErr
631+ }
632+ }
633+
634+ return NewBedrock (client , opts .ModelID , func (o * BedrockOptions ) {
635+ o .CallbackOptions = opts .CallbackOptions
636+ o .Tokenizer = opts .Tokenizer
637+ o .ModelParams = map [string ]any {
638+ "temperature" : opts .Temperature ,
639+ "top_p" : opts .TopP ,
640+ "top_k" : opts .TopK ,
641+ "max_tokens" : opts .MaxTokens ,
642+ }
643+ o .Stream = opts .Stream
644+ })
645+ }
646+
552647// BedrockOptions contains options for configuring the Bedrock LLM model.
553648type BedrockOptions struct {
554649 * schema.CallbackOptions `map:"-"`
0 commit comments