Skip to content

Commit

Permalink
feat: integrated text embeddings into application
Browse files Browse the repository at this point in the history
feat: integrated text embeddings into application
  • Loading branch information
telpirion authored Dec 6, 2024
2 parents f6e8015 + 2a948b4 commit 0eb6b7b
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 51 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ COPY prompts ./server/templates
COPY server/favicon.ico ./server/favicon.ico
COPY server/generated ./server/generated
COPY server/ai ./server/ai
COPY server/databases ./server/databases
COPY server/*.go ./server

COPY server/go.mod server/go.sum ./server/
Expand Down
25 changes: 14 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ These models have been evaluated against the following set of metrics.
+ [Groundedness][groundedness]
+ [Coherence][coherence]

The following table shows the evaluation scores for each of these models.
The following table shows the evaluation scores for each of these models. Change
from previous evaluation runs provided in parentheses.

| Model | ROUGE | Closed domain | Open domain | Groundedness | Coherence | Date of eval |
| ---------------- | ------ | ------------- | ----------- | ------------ | --------- | ------------ |
| Gemini 1.5 Flash | 0.20[1]| 0.0 | 1.0 | 1.0[1] | 3.3 | 2024-11-25 |
| Tuned Gemini | 0.21 | 0.4 | 1.0 | 1.0 | 2.4 | 2024-11-25 |
| Gemma | 0.05 | 0.6 | 0.4 | 0.8 | 1.4 | 2024-11-25 |
| Model | ROUGE | Closed domain | Open domain | Groundedness | Coherence | Date of eval |
| ---------------------| ------------ | ------------- | ----------- | ------------ | ---------- | ------------ |
| Gemini 1.5 Flash [1] | 0.35 (+0.15) | 0.56 (+0.56) | 1.0 (0.0) | 1.0 (0.0) | 3.3 (-0.3) | 2024-11-27 |
| Tuned Gemini | 0.26 (+0.05) | 0.6 (+0.2) | 1.0 (0.0) | 0.8 (-0.2) | 3.2 (+0.4) | 2024-11-27 |
| Gemma | 0.10 (+0.05) | 0.9 (+0.3) | 0.8 (+0.4) | 0.8 (0.0) | 2.2 (+0.8) | 2024-11-27 |
| Reddit-agent Gemini  | 0.11         | 1.0           | 0.8         | 0.2          | 1.8        | 2024-11-27   |

[1]: Gemini 1.5 Flash responses from 2024-11-05 are used as the ground truth
for all other models.
Expand All @@ -72,11 +74,12 @@ techniques.

The following table shows the evaluation scores for adversarial prompting.

| Model | Prompt injection | Prompt leaking | Jailbreaking | Date of eval |
| ---------------- | ----------------- | -------------- | ------------ | ------------ |
| Gemini 1.5 Flash | 0.66 | 0.66 | 1.0 | 2024-11-25 |
| Tuned Gemini | 0.33 | 1.0 | 1.0 | 2024-11-25 |
| Gemma | 1.0 | 0.66 | 0.66 | 2024-11-25 |
| Model | Prompt injection | Prompt leaking | Jailbreaking | Date of eval |
| ------------------- | ----------------- | -------------- | ------------ | ------------ |
| Gemini 1.5 Flash | FAIL | FAIL | PASS | 2024-11-27 |
| Tuned Gemini | FAIL | PASS | PASS | 2024-11-27 |
| Gemma | FAIL | FAIL | PASS | 2024-11-27 |
| Reddit-agent Gemini | PASS | PASS | FAIL | 2024-11-27 |

[bigquery]: https://cloud.google.com/bigquery/docs
[bulma]: https://bulma.io/documentation/components/message/
Expand Down
109 changes: 102 additions & 7 deletions server/ai/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/structpb"

db "github.com/telpirion/MyHerodotus/databases"
"github.com/telpirion/MyHerodotus/generated"
)

Expand All @@ -28,7 +29,9 @@ const (
GemmaTemplate = "templates/gemma.2024.10.25.tmpl"
GeminiModel = "gemini-1.5-flash-001"
HistoryTemplate = "templates/conversation_history.tmpl"
EmbeddingModelName = "text-embedding-005"
MaxGemmaTokens int32 = 2048
location = "us-west1"
)

var (
Expand Down Expand Up @@ -81,12 +84,14 @@ func Predict(query, modality, projectID string) (response string, templateName s
response, err = textPredictGemini(query, projectID, GeminiTuned)
case AgentAssisted:
response, err = textPredictWithReddit(query, projectID)
case EmbeddingsAssisted:
response, err = textPredictWithEmbeddings(query, projectID)
default:
response, err = textPredictGemini(query, projectID, Gemini)
}

if err != nil {
return "", "", nil
return "", "", err
}

cachedContext += fmt.Sprintf("### Human: %s\n### Assistant: %s\n", query, response)
Expand All @@ -95,7 +100,6 @@ func Predict(query, modality, projectID string) (response string, templateName s

// GetTokenCount uses the Gemini tokenizer to count the tokens in some text.
func GetTokenCount(text, projectID string) (int32, error) {
location := "us-west1"
ctx := context.Background()
client, err := genai.NewClient(ctx, projectID, location)
if err != nil {
Expand Down Expand Up @@ -135,7 +139,6 @@ func StoreConversationContext(conversationHistory []generated.ConversationBit, p
}

ctx := context.Background()
location := "us-west1"
client, err := genai.NewClient(ctx, projectID, location)
if err != nil {
return "", fmt.Errorf("unable to create client: %w", err)
Expand Down Expand Up @@ -208,7 +211,6 @@ func createPrompt(message, templateName, history string) (string, error) {
// textPredictGemma2 generates text using a Gemma2 hosted model
func textPredictGemma(message, projectID string) (string, error) {
ctx := context.Background()
location := "us-west1"
endpointID := os.Getenv("ENDPOINT_ID")
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)

Expand Down Expand Up @@ -265,8 +267,6 @@ func textPredictGemma(message, projectID string) (string, error) {
// textPredictGemini generates text using a Gemini 1.5 Flash model
func textPredictGemini(message, projectID string, modality Modality) (string, error) {
ctx := context.Background()
location := "us-west1"

client, err := genai.NewClient(ctx, projectID, location)
if err != nil {
return "", err
Expand Down Expand Up @@ -296,7 +296,7 @@ func textPredictGemini(message, projectID string, modality Modality) (string, er

candidate, err := getCandidate(resp)
if err != nil {
return "I'm not sure how to answer that. Would you please repeat the question?", nil
return "", nil
}
return extractAnswer(candidate), nil
}
Expand Down Expand Up @@ -397,3 +397,98 @@ func textPredictWithReddit(query, projectID string) (string, error) {
output := string(res.Candidates[0].Content.Parts[0].(genai.Text))
return output, nil
}

// textPredictWithEmbeddings generates a response to the user's query using
// Gemini, augmented with conversation context retrieved by vector-similarity
// search over stored text embeddings (RAG).
func textPredictWithEmbeddings(query, projectID string) (string, error) {

	// Embed the query so it can be matched against stored embeddings.
	queryEmbed, err := getQueryTextEmbedding(query, projectID)
	if err != nil {
		return "", err
	}

	// Retrieve stored context semantically close to the query embedding.
	embeddingContext, err := db.GetEmbedding(queryEmbed, projectID)
	if err != nil {
		return "", err
	}

	ctx := context.Background()
	client, err := genai.NewClient(ctx, projectID, location)
	if err != nil {
		return "", err
	}
	defer client.Close()

	llm := client.GenerativeModel(GeminiModel)

	// BUG FIX: previously the templated prompt returned by createPrompt was
	// discarded (and its error ignored), so the raw query was sent to the
	// model and the retrieved embedding context had no effect. Use the
	// rendered prompt for generation.
	prompt, err := createPrompt(query, GeminiTemplate, embeddingContext)
	if err != nil {
		return "", err
	}

	resp, err := llm.GenerateContent(ctx, genai.Text(prompt))
	if err != nil {
		return "", fmt.Errorf("embeddings-assisted generation: %w", err)
	}

	candidate, err := getCandidate(resp)
	if err != nil {
		return "", err
	}
	return extractAnswer(candidate), nil
}

// getQueryTextEmbedding converts a user query into a dense vector using the
// Vertex AI text embedding model (EmbeddingModelName). On success the returned
// slice has outputDimensionality (128) elements; on failure it returns a nil
// slice and a non-nil error.
func getQueryTextEmbedding(query, projectID string) ([]float32, error) {

	ctx := context.Background()

	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// Reduced output dimensionality; must match the dimensionality of the
	// vectors already stored in the embeddings database.
	const dimensionality = 128

	client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	if err != nil {
		return nil, err
	}
	defer client.Close()

	endpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s",
		projectID, location, EmbeddingModelName)

	// A single instance: the query text, marked as a retrieval query so the
	// model produces an embedding suitable for similarity search.
	instance := structpb.NewStructValue(&structpb.Struct{
		Fields: map[string]*structpb.Value{
			"content":   structpb.NewStringValue(query),
			"task_type": structpb.NewStringValue("RETRIEVAL_QUERY"),
		},
	})

	params := structpb.NewStructValue(&structpb.Struct{
		Fields: map[string]*structpb.Value{
			"outputDimensionality": structpb.NewNumberValue(float64(dimensionality)),
		},
	})

	req := &aiplatformpb.PredictRequest{
		Endpoint:   endpoint,
		Instances:  []*structpb.Value{instance},
		Parameters: params,
	}
	resp, err := client.Predict(ctx, req)
	if err != nil {
		return nil, err
	}

	if len(resp.Predictions) == 0 {
		return nil, fmt.Errorf("vertex: text embeddings: no embeddings created")
	}

	// BUG FIX: use the nil-safe generated getters (GetFields/GetValues)
	// instead of direct field access, so a malformed prediction yields an
	// error rather than a nil-pointer panic. Only the first (and only)
	// prediction is needed, so the intermediate 2-D slice is gone.
	values := resp.Predictions[0].GetStructValue().GetFields()["embeddings"].
		GetStructValue().GetFields()["values"].GetListValue().GetValues()
	if len(values) == 0 {
		return nil, fmt.Errorf("vertex: text embeddings: empty embedding in response")
	}

	embedding := make([]float32, len(values))
	for i, v := range values {
		embedding[i] = float32(v.GetNumberValue())
	}
	return embedding, nil
}
6 changes: 4 additions & 2 deletions server/ai/vertex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package ai
import (
"strings"
"testing"

"github.com/telpirion/MyHerodotus/generated"
)

func TestCreatePrompt(t *testing.T) {
query := "I'm a query"
got, err := createPrompt(query, GeminiTemplate)
got, err := createPrompt(query, GeminiTemplate, "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -18,7 +20,7 @@ func TestCreatePrompt(t *testing.T) {
}

func TestSetConversationContext(t *testing.T) {
convoHistory := []ConversationBit{
convoHistory := []generated.ConversationBit{
{
UserQuery: "test user query",
BotResponse: "test bot response",
Expand Down
14 changes: 5 additions & 9 deletions server/db.go → server/databases/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ Herodotus [
]
*/
package main
package databases

import (
"context"
"fmt"
"os"

"cloud.google.com/go/firestore"
Expand Down Expand Up @@ -60,7 +59,7 @@ type ConversationHistory struct {
Conversations []generated.ConversationBit
}

func saveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) {
func SaveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) {
ctx := context.Background()

// Get CollectionName for running in staging or prod
Expand All @@ -71,7 +70,6 @@ func saveConversation(convo generated.ConversationBit, userEmail, projectID stri

client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName)
if err != nil {
LogError(fmt.Sprintf("firestore.Client: %v\n", err))
return "", err
}
defer client.Close()
Expand All @@ -83,7 +81,7 @@ func saveConversation(convo generated.ConversationBit, userEmail, projectID stri
return docRef.ID, err
}

func updateConversation(documentId, userEmail, rating, projectID string) error {
func UpdateConversation(documentId, userEmail, rating, projectID string) error {

// Get CollectionName for running in staging or prod
_collectionName, ok := os.LookupEnv("COLLECTION_NAME")
Expand All @@ -106,12 +104,11 @@ func updateConversation(documentId, userEmail, rating, projectID string) error {
return nil
}

func getConversation(userEmail, projectID string) ([]generated.ConversationBit, error) {
func GetConversation(userEmail, projectID string) ([]generated.ConversationBit, error) {
ctx := context.Background()
conversations := []generated.ConversationBit{}
client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName)
if err != nil {
LogError(fmt.Sprintf("firestore.Client: %v\n", err))
return conversations, err
}
defer client.Close()
Expand All @@ -125,13 +122,12 @@ func getConversation(userEmail, projectID string) ([]generated.ConversationBit,
break
}
if err != nil {
LogError(fmt.Sprintf("Firestore Iterator: %v\n", err))
return conversations, err
}
var convo generated.ConversationBit
err = doc.DataTo(&convo)
if err != nil {
LogError(fmt.Sprintf("Firestore document unmarshaling: %v\n", err))

continue
}
conversations = append(conversations, convo)
Expand Down
25 changes: 13 additions & 12 deletions server/db_test.go → server/databases/db_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package databases

import (
"context"
Expand All @@ -8,6 +8,7 @@ import (
"time"

"cloud.google.com/go/firestore"
"github.com/telpirion/MyHerodotus/generated"
)

var (
Expand All @@ -29,18 +30,18 @@ func TestMain(m *testing.M) {

collection := client.Collection(_collectionName)
subcollection := collection.Doc(email2).Collection(SubCollectionName)
_, _, err = subcollection.Add(ctx, ConversationBit{
_, _, err = subcollection.Add(ctx, generated.ConversationBit{
UserQuery: "test user query",
BotResponse: "test bot response",
Created: time.Now(),
Created: time.Now().Unix(),
})
if err != nil {
log.Fatal(err)
}
_, _, err = subcollection.Add(ctx, ConversationBit{
_, _, err = subcollection.Add(ctx, generated.ConversationBit{
UserQuery: "test user query 2",
BotResponse: "test bot response 2",
Created: time.Now(),
Created: time.Now().Unix(),
})
if err != nil {
log.Fatal(err)
Expand All @@ -61,12 +62,12 @@ func TestMain(m *testing.M) {
}

func TestSaveConversation(t *testing.T) {
convo := &ConversationBit{
convo := &generated.ConversationBit{
UserQuery: "This is from unit test",
BotResponse: "This is a bot response",
Created: time.Now(),
Created: time.Now().Unix(),
}
id, err := saveConversation(*convo, email, _projectID)
id, err := SaveConversation(*convo, email, _projectID)
if err != nil {
t.Fatal(err)
}
Expand All @@ -75,12 +76,12 @@ func TestSaveConversation(t *testing.T) {
t.Error("Empty document ID")
}

nextConvo := &ConversationBit{
nextConvo := &generated.ConversationBit{
UserQuery: "This is also from a unit test",
BotResponse: "This is another fake bot response",
Created: time.Now(),
Created: time.Now().Unix(),
}
nextID, err := saveConversation(*nextConvo, email, _projectID)
nextID, err := SaveConversation(*nextConvo, email, _projectID)
if err != nil {
t.Fatalf("Error on adding next conversation: %v\n\n", err)
}
Expand All @@ -90,7 +91,7 @@ func TestSaveConversation(t *testing.T) {
}

func TestGetConversation(t *testing.T) {
conversations, err := getConversation(email2, projectID)
conversations, err := GetConversation(email2, _projectID)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit 0eb6b7b

Please sign in to comment.