Skip to content

Commit

Permalink
feat: oci genai chat models
Browse files Browse the repository at this point in the history
Signed-off-by: Anders Swanson <[email protected]>
  • Loading branch information
anders-swanson committed Feb 11, 2025
1 parent fcc8563 commit 49c44a3
Showing 1 changed file with 107 additions and 28 deletions.
135 changes: 107 additions & 28 deletions pkg/ai/ocigenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,32 @@ package ai
import (
"context"
"errors"
"fmt"
"github.com/oracle/oci-go-sdk/v65/common"
"github.com/oracle/oci-go-sdk/v65/generativeai"
"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
"strings"
"reflect"
)

// ociClientName is the registry key under which this AI backend is exposed.
const ociClientName = "oci"

// ociModelVendor identifies the upstream model vendor (e.g. "cohere", "meta"),
// which determines the chat request payload shape expected by the OCI
// Generative AI inference service.
type ociModelVendor string

const (
	// Typed as ociModelVendor so they compare directly against getVendor()
	// without implicit untyped-constant conversion.
	vendorCohere ociModelVendor = "cohere"
	vendorMeta   ociModelVendor = "meta"
)

// OCIGenAIClient talks to the Oracle Cloud Infrastructure Generative AI
// inference service. Close() is a no-op via the embedded nopCloser.
type OCIGenAIClient struct {
	nopCloser

	client *generativeaiinference.GenerativeAiInferenceClient
	// model holds metadata (vendor, base/custom type) fetched from the
	// management API; may be nil, in which case vendor defaults to cohere
	// and a dedicated serving mode is assumed — see getVendor/isBaseModel.
	model *generativeai.Model
	// modelID is either a base model OCID (on-demand serving) or a
	// dedicated endpoint OCID, depending on isBaseModel().
	modelID       string
	compartmentId string
	temperature   float32
	topP          float32
	topK          int32
	maxTokens     int
}

Expand All @@ -40,9 +51,10 @@ func (c *OCIGenAIClient) GetName() string {

func (c *OCIGenAIClient) Configure(config IAIConfig) error {
config.GetEndpointName()
c.model = config.GetModel()
c.modelID = config.GetModel()

Check warning on line 54 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L54

Added line #L54 was not covered by tests
c.temperature = config.GetTemperature()
c.topP = config.GetTopP()
c.topK = config.GetTopK()

Check warning on line 57 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L57

Added line #L57 was not covered by tests
c.maxTokens = config.GetMaxTokens()
c.compartmentId = config.GetCompartmentId()
provider := common.DefaultConfigProvider()
Expand All @@ -55,43 +67,110 @@ func (c *OCIGenAIClient) Configure(config IAIConfig) error {
}

func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
generateTextRequest := c.newGenerateTextRequest(prompt)
generateTextResponse, err := c.client.GenerateText(ctx, generateTextRequest)
request := c.newChatRequest(prompt)
response, err := c.client.Chat(ctx, request)

Check warning on line 71 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L70-L71

Added lines #L70 - L71 were not covered by tests
if err != nil {
return "", err
}
return extractGeneratedText(generateTextResponse.InferenceResponse)
if err != nil {
return "", err
}
return extractGeneratedText(response.ChatResponse)

Check warning on line 78 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L75-L78

Added lines #L75 - L78 were not covered by tests
}

func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
func (c *OCIGenAIClient) newChatRequest(prompt string) generativeaiinference.ChatRequest {
return generativeaiinference.ChatRequest{
ChatDetails: generativeaiinference.ChatDetails{
CompartmentId: &c.compartmentId,
ServingMode: c.getServingMode(),
ChatRequest: c.getChatModelRequest(prompt),
},
}

Check warning on line 88 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L81-L88

Added lines #L81 - L88 were not covered by tests
}

func (c *OCIGenAIClient) getChatModelRequest(prompt string) generativeaiinference.BaseChatRequest {

Check warning on line 91 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L91

Added line #L91 was not covered by tests
temperatureF64 := float64(c.temperature)
topPF64 := float64(c.topP)
return generativeaiinference.GenerateTextRequest{
GenerateTextDetails: generativeaiinference.GenerateTextDetails{
CompartmentId: &c.compartmentId,
ServingMode: generativeaiinference.OnDemandServingMode{
ModelId: &c.model,
},
InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
Prompt: &prompt,
MaxTokens: &c.maxTokens,
Temperature: &temperatureF64,
TopP: &topPF64,
topK := int(c.topP)

switch c.getVendor() {
case vendorMeta:
messages := []generativeaiinference.Message{
generativeaiinference.UserMessage{
Content: []generativeaiinference.ChatContent{
generativeaiinference.TextContent{
Text: &prompt,
},
},

Check warning on line 104 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L94-L104

Added lines #L94 - L104 were not covered by tests
},
},
}
return generativeaiinference.GenericChatRequest{
Messages: messages,
TopK: &topK,
TopP: &topPF64,
Temperature: &temperatureF64,
MaxTokens: &c.maxTokens,
}
default: // Default to cohere
return generativeaiinference.CohereChatRequest{
Message: &prompt,
MaxTokens: &c.maxTokens,
Temperature: &temperatureF64,
TopK: &topK,
TopP: &topPF64,
}

Check warning on line 121 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L106-L121

Added lines #L106 - L121 were not covered by tests

}
}

func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
if !ok {
return "", errors.New("failed to extract generated text from backed response")
func extractGeneratedText(llmInferenceResponse generativeaiinference.BaseChatResponse) (string, error) {
switch response := llmInferenceResponse.(type) {
case generativeaiinference.GenericChatResponse:
if len(response.Choices) > 0 && len(response.Choices[0].Message.GetContent()) > 0 {
if content, ok := response.Choices[0].Message.GetContent()[0].(generativeaiinference.TextContent); ok {
return *content.Text, nil
}

Check warning on line 132 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L126-L132

Added lines #L126 - L132 were not covered by tests
}
return "", errors.New("no text found in oci response")
case generativeaiinference.CohereChatResponse:
return *response.Text, nil
default:
return "", fmt.Errorf("unknown oci response type: %s", reflect.TypeOf(llmInferenceResponse).Name())

Check warning on line 138 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L134-L138

Added lines #L134 - L138 were not covered by tests
}
sb := strings.Builder{}
for _, text := range response.GeneratedTexts {
if text.Text != nil {
sb.WriteString(*text.Text)
}

func (c *OCIGenAIClient) getServingMode() generativeaiinference.ServingMode {
if c.isBaseModel() {
return generativeaiinference.OnDemandServingMode{
ModelId: &c.modelID,

Check warning on line 145 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L142-L145

Added lines #L142 - L145 were not covered by tests
}
}
return sb.String(), nil
return generativeaiinference.DedicatedServingMode{
EndpointId: &c.modelID,
}

Check warning on line 150 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L148-L150

Added lines #L148 - L150 were not covered by tests
}

func (c *OCIGenAIClient) getModel(provider common.ConfigurationProvider) (*generativeai.Model, error) {

Check failure on line 153 in pkg/ai/ocigenai.go

View workflow job for this annotation

GitHub Actions / golangci-lint

[golangci] reported by reviewdog 🐶 func `(*OCIGenAIClient).getModel` is unused (unused) Raw Output: pkg/ai/ocigenai.go:153:26: func `(*OCIGenAIClient).getModel` is unused (unused) func (c *OCIGenAIClient) getModel(provider common.ConfigurationProvider) (*generativeai.Model, error) { ^
client, err := generativeai.NewGenerativeAiClientWithConfigurationProvider(provider)
if err != nil {
return nil, err
}
response, err := client.GetModel(context.Background(), generativeai.GetModelRequest{
ModelId: &c.modelID,
})
if err != nil {
return nil, err
}
return &response.Model, nil

Check warning on line 164 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L153-L164

Added lines #L153 - L164 were not covered by tests
}

func (c *OCIGenAIClient) isBaseModel() bool {
return c.model != nil && c.model.Type == generativeai.ModelTypeBase

Check warning on line 168 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L167-L168

Added lines #L167 - L168 were not covered by tests
}

func (c *OCIGenAIClient) getVendor() ociModelVendor {
if c.model == nil || c.model.Vendor == nil {
return ""
}
return ociModelVendor(*c.model.Vendor)

Check warning on line 175 in pkg/ai/ocigenai.go

View check run for this annotation

Codecov / codecov/patch

pkg/ai/ocigenai.go#L171-L175

Added lines #L171 - L175 were not covered by tests
}

0 comments on commit 49c44a3

Please sign in to comment.