From cf61e15d872e1b5e1e2e4b7dec03c302ababe934 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Wed, 27 Nov 2024 22:27:55 +0000 Subject: [PATCH 1/6] feat: updates to evaluation service --- README.md | 25 ++++++++++++++----------- services/evaluations/Dockerfile | 1 + services/evaluations/cloudbuild.yaml | 4 ++-- services/evaluations/evaluations.py | 1 + services/evaluations/herodotus_model.py | 2 +- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index dd2157a..bb6bef1 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,15 @@ These models have been evaluated against the following set of metrics. + [Groundedness][groundedness] + [Coherence][coherence] -The following table shows the evaluation scores for each of these models. +The following table shows the evaluation scores for each of these models. Change +from previous evaluation runs provided in parentheses. -| Model | ROUGE | Closed domain | Open domain | Groundedness | Coherence | Date of eval | -| ---------------- | ------ | ------------- | ----------- | ------------ | --------- | ------------ | -| Gemini 1.5 Flash | 0.20[1]| 0.0 | 1.0 | 1.0[1] | 3.3 | 2024-11-25 | -| Tuned Gemini | 0.21 | 0.4 | 1.0 | 1.0 | 2.4 | 2024-11-25 | -| Gemma | 0.05 | 0.6 | 0.4 | 0.8 | 1.4 | 2024-11-25 | +| Model | ROUGE | Closed domain | Open domain | Groundedness | Coherence | Date of eval | +| ---------------------| ------------ | ------------- | ----------- | ------------ | ---------- | ------------ | +| Gemini 1.5 Flash [1] | 0.35 (+0.15) | 0.56 (+0.56) | 1.0 (0.0) | 1.0 (0.0) | 3.3 (-0.3) | 2024-11-27 | +| Tuned Gemini | 0.26 (+0.05) | 0.6 (+0.2) | 1.0 (0.0) | 0.8 (-0.2) | 3.2 (+0.4) | 2024-11-27 | +| Gemma | 0.10 (+0.05) | 0.9 (+0.3) | 0.8 (+0.4) | 0.8 (0.0) | 2.2 (+0.8) | 2024-11-27 | +| Reddit-agent Gemini | 0.11 | 1.0 | 0.8. | 0.2 | 1.8 | 2024-11-27 | [1]: Gemini 1.5 Flash responses from 2024-11-05 are used as the ground truth for all other models. @@ -72,11 +74,12 @@ techniques. The following table shows the evaluation scores for adversarial prompting. -| Model | Prompt injection | Prompt leaking | Jailbreaking | Date of eval | -| ---------------- | ----------------- | -------------- | ------------ | ------------ | -| Gemini 1.5 Flash | 0.66 | 0.66 | 1.0 | 2024-11-25 | -| Tuned Gemini | 0.33 | 1.0 | 1.0 | 2024-11-25 | -| Gemma | 1.0 | 0.66 | 0.66 | 2024-11-25 | +| Model | Prompt injection | Prompt leaking | Jailbreaking | Date of eval | +| ------------------- | ----------------- | -------------- | ------------ | ------------ | +| Gemini 1.5 Flash | FAIL | FAIL | PASS | 2024-11-27 | +| Tuned Gemini | FAIL | PASS | PASS | 2024-11-27 | +| Gemma | FAIL | FAIL | PASS | 2024-11-27 | +| Reddit-agent Gemini | PASS | PASS | FAIL | 2024-11-27 | [bigquery]: https://cloud.google.com/bigquery/docs [bulma]: https://bulma.io/documentation/components/message/ diff --git a/services/evaluations/Dockerfile b/services/evaluations/Dockerfile index 1690bc9..5237e70 100644 --- a/services/evaluations/Dockerfile +++ b/services/evaluations/Dockerfile @@ -5,6 +5,7 @@ WORKDIR / COPY evaluations.py ./evaluations.py COPY metrics.py ./metrics.py COPY prompts.py ./prompts.py +COPY herodotus_model.py ./herodotus_model.py COPY requirements.txt ./requirements.txt RUN pip install -r requirements.txt diff --git a/services/evaluations/cloudbuild.yaml b/services/evaluations/cloudbuild.yaml index 5b67f36..b3c7673 100644 --- a/services/evaluations/cloudbuild.yaml +++ b/services/evaluations/cloudbuild.yaml @@ -3,7 +3,7 @@ steps: env: - 'DATASET_NAME=myherodotus' script: | - docker build -t us-west1-docker.pkg.dev/$PROJECT_ID/my-herodotus/evaluations:v0.2.0 . + docker build -t us-west1-docker.pkg.dev/$PROJECT_ID/my-herodotus/evaluations:v0.3.0 . automapSubstitutions: true images: -- 'us-west1-docker.pkg.dev/$PROJECT_ID/my-herodotus/evaluations:v0.2.0' \ No newline at end of file +- 'us-west1-docker.pkg.dev/$PROJECT_ID/my-herodotus/evaluations:v0.3.0' \ No newline at end of file diff --git a/services/evaluations/evaluations.py b/services/evaluations/evaluations.py index 3428a5e..eb88e3c 100644 --- a/services/evaluations/evaluations.py +++ b/services/evaluations/evaluations.py @@ -64,6 +64,7 @@ def main(): ("gemini", "gemini_1_5_flash_001"), ("gemini-tuned", "tuned_gemini"), ("gemma", "gemma"), + ("agent-assisted", "agent-assisted") ] for m in models: model_id, model_name = m diff --git a/services/evaluations/herodotus_model.py b/services/evaluations/herodotus_model.py index 5d7a7a8..f63421d 100644 --- a/services/evaluations/herodotus_model.py +++ b/services/evaluations/herodotus_model.py @@ -16,7 +16,7 @@ def __init__(self, modality): @property def _model_name(self) -> str: - return "gemini_1_5_flash_001" + return self.modality def generate_content(self, prompt: str): payload = {"message": prompt, "model": self.modality} From 300f3b4f5b1d75d9a167181c6b2d8f5059af213f Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Mon, 2 Dec 2024 13:07:07 -0800 Subject: [PATCH 2/6] fix: minor issues --- server/ai/vertex_test.go | 6 ++++-- server/db_test.go | 17 +++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/server/ai/vertex_test.go b/server/ai/vertex_test.go index d9a755b..f17dcfe 100644 --- a/server/ai/vertex_test.go +++ b/server/ai/vertex_test.go @@ -3,11 +3,13 @@ package ai import ( "strings" "testing" + + "github.com/telpirion/MyHerodotus/generated" ) func TestCreatePrompt(t *testing.T) { query := "I'm a query" - got, err := createPrompt(query, GeminiTemplate) + got, err := createPrompt(query, GeminiTemplate, "") if err != nil { t.Fatal(err) } @@ -18,7 +20,7 @@ func TestCreatePrompt(t *testing.T) { } func TestSetConversationContext(t *testing.T) { - convoHistory := []ConversationBit{ + convoHistory := []generated.ConversationBit{ { UserQuery: "test user query", BotResponse: "test bot response", diff --git a/server/db_test.go b/server/db_test.go index 237a45c..d342a67 100644 --- a/server/db_test.go +++ b/server/db_test.go @@ -8,6 +8,7 @@ import ( "time" "cloud.google.com/go/firestore" + "github.com/telpirion/MyHerodotus/generated" ) var ( @@ -29,18 +30,18 @@ func TestMain(m *testing.M) { collection := client.Collection(_collectionName) subcollection := collection.Doc(email2).Collection(SubCollectionName) - _, _, err = subcollection.Add(ctx, ConversationBit{ + _, _, err = subcollection.Add(ctx, generated.ConversationBit{ UserQuery: "test user query", BotResponse: "test bot response", - Created: time.Now(), + Created: time.Now().Unix(), }) if err != nil { log.Fatal(err) } - _, _, err = subcollection.Add(ctx, ConversationBit{ + _, _, err = subcollection.Add(ctx, generated.ConversationBit{ UserQuery: "test user query 2", BotResponse: "test bot response 2", - Created: time.Now(), + Created: time.Now().Unix(), }) if err != nil { log.Fatal(err) @@ -61,10 +62,10 @@ func TestMain(m *testing.M) { } func TestSaveConversation(t *testing.T) { - convo := &ConversationBit{ + convo := &generated.ConversationBit{ UserQuery: "This is from unit test", BotResponse: "This is a bot response", - Created: time.Now(), + Created: time.Now().Unix(), } id, err := saveConversation(*convo, email, _projectID) if err != nil { @@ -75,10 +76,10 @@ func TestSaveConversation(t *testing.T) { t.Error("Empty document ID") } - nextConvo := &ConversationBit{ + nextConvo := &generated.ConversationBit{ UserQuery: "This is also from a unit test", BotResponse: "This is another fake bot response", - Created: time.Now(), + Created: time.Now().Unix(), } nextID, err := saveConversation(*nextConvo, email, _projectID) if err != nil { From 1e7268f2e74d6d170e87262bcba7285ed4dc32a1 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Mon, 2 Dec 2024 21:09:13 +0000 Subject: [PATCH 3/6] fix: updated embeddings, some FE fixes --- server/ai/embeddings.go | 37 +++++++++++++++++++++++++++++++++++++ server/ai/vertex.go | 34 ++++++++++++++++++++++++++++++++++ site/html/index.html | 1 + site/js/index.js | 4 ++++ 4 files changed, 76 insertions(+) create mode 100644 server/ai/embeddings.go diff --git a/server/ai/embeddings.go b/server/ai/embeddings.go new file mode 100644 index 0000000..7e28f72 --- /dev/null +++ b/server/ai/embeddings.go @@ -0,0 +1,37 @@ +package ai + +import ( + "context" + + "cloud.google.com/go/firestore" +) + +var testQuery = []float32{ + 0.02855090983211994, -0.035435665398836136, -0.030386969447135925, -0.02779807522892952, 0.06222547963261604, 0.032307568937540054, 0.006095572374761105, -0.017693594098091125, 0.03338219225406647, 0.052011892199516296, -0.0026677092537283897, -0.024139121174812317, 0.03690512105822563, -0.01943613775074482, -0.012934908270835876, -0.010278051719069481, 0.012921503745019436, 0.01350736990571022, 0.03510843962430954, -0.05104907602071762, -0.015654291957616806, 0.009973280131816864, 0.03959834948182106, 0.0084367161616683, 0.002778309164568782, -0.021186387166380882, -0.009058095514774323, -0.06550347059965134, -0.00020016275811940432, 0.03002917394042015, -0.0797576755285263, 0.01018990483134985, -0.07638117671012878, -0.004428969230502844, 0.02399042248725891, -0.027329476550221443, 0.01227719709277153, -0.009645448997616768, -0.022647777572274208, 0.02372140996158123, 0.00889495573937893, 0.044717513024806976, -0.0004786151403095573, 0.013710719533264637, 0.013255547732114792, -0.013190221041440964, -0.03924371674656868, -0.026655839756131172, -0.03915565460920334, -0.03143079951405525, 0.00671390863135457, -0.00036247173557057977, -0.019021255895495415, -0.007367352023720741, 0.023029474541544914, -0.05825888738036156, 0.07291541993618011, 0.01019354723393917, -0.03974568843841553, 0.028182122856378555, 0.02223936840891838, 0.025718091055750847, -0.009822058491408825, 0.0294659286737442, 0.045444194227457047, -0.05296202003955841, -0.03392653167247772, -0.028148580342531204, 0.08328374475240707, -0.005129489582031965, 0.04430880770087242, -0.005098880268633366, 0.026339663192629814, -0.002592273522168398, -0.03713815659284592, -0.05650784820318222, -0.00985332578420639, 0.06560691446065903, 0.04158855602145195, 0.0029465670231729746, 0.001012718304991722, -0.06649360060691833, -0.07900629937648773, -0.08239595592021942, -0.040955450385808945, -0.005131527315825224, -0.04057120159268379, -0.007040790747851133, -0.01463773287832737, 0.09051593393087387, 0.00812543649226427, 0.02343042939901352, 0.01811213046312332, -0.09880316257476807, -0.003164332127198577, 0.04705613851547241, 0.020249972119927406, -0.004456446506083012, -0.06126254051923752, -0.015342309139668941, -0.03331226855516434, -0.04447592794895172, -0.03512626513838768, -0.013713927939534187, 0.04527364671230316, 0.003716537728905678, 0.024795016273856163, 0.04758043587207794, 0.01533534936606884, 0.055674780160188675, -0.10856613516807556, 0.015341871418058872, 0.02090427279472351, -0.013197915628552437, -0.016130618751049042, -0.006240378133952618, -0.029776113107800484, 0.06432545930147171, -0.019419798627495766, -0.05422724038362503, 0.0129588283598423, -0.02371649071574211, 0.030254937708377838, 0.003299255622550845, 0.007495594676584005, 0.054419200867414474, 0.006257002707570791, 0.00789280142635107, 0.04340459406375885, 0.05346832424402237, -0.029386304318904877, -0.028848648071289062, 0.03731193393468857, 0.037310048937797546, 0.05145693197846413, 0.05570625886321068, 0.037225741893053055, -0.00999367143958807, 0.03503532335162163, 0.025289271026849747, 0.06072063371539116, 0.04446175694465637, -0.025419248268008232, -0.0026832991279661655, -0.01518651470541954, -0.013102266006171703, -0.03348962217569351, -0.007308039348572493, 0.025298891589045525, -0.020990043878555298, -0.0009921509772539139, -0.0012421614956110716, -0.021900959312915802, -0.02093563787639141, 0.12074049562215805, -0.0015050763031467795, 0.023123204708099365, 0.00012658465129788965, 0.01790185086429119, -0.03898821026086807, 0.048819027841091156, 0.02707771584391594, -0.02640371583402157, 0.020628876984119415, 0.020295720547437668, -0.04107430949807167, 0.051317933946847916, -0.035174526274204254, -0.01880166307091713, -0.005285138729959726, -0.013532846234738827, -0.041030798107385635, -0.05447480082511902, -0.006849688943475485, -0.06334996968507767, -0.021188974380493164, -0.049942947924137115, -0.05172604322433472, -0.02559443935751915, 0.01254394929856062, -0.04415528103709221, 0.00871290359646082, 0.00595787912607193, -0.02759362757205963, 0.07408977299928665, 0.022976472973823547, 0.0688139796257019, -0.06209009885787964, 0.0003142820205539465, -0.07047729939222336, 0.02597867324948311, 0.01738431304693222, -0.0272830743342638, 0.0016810598317533731, -0.015276334248483181, -0.00023947120644152164, 0.021162616088986397, 0.013781429268419743, -0.008525359444320202, -0.0742156133055687, -0.033778589218854904, 0.055541202425956726, -0.005239429883658886, -0.022989504039287567, 0.025311576202511787, 0.008473692461848259, 0.03572787344455719, -0.028337523341178894, -0.002087587956339121, 0.028023691847920418, -0.04168621823191643, 0.02454696223139763, -0.07450693845748901, 0.00983236264437437, -0.001956229330971837, -0.013308721594512463, -0.019844714552164078, 0.022057563066482544, 0.022406132891774178, -0.03200545161962509, -0.0007785293273627758, 0.009939076378941536, -0.034312572330236435, 0.02353755757212639, -0.0037626009434461594, 0.019475946202874184, 0.0039074718952178955, 0.014623363502323627, 0.006834374740719795, -0.08410128951072693, 0.018348829820752144, 0.031841207295656204, 0.06144539266824722, -0.043109212070703506, 0.02122630923986435, -0.03643026947975159, -0.012863168492913246, -0.005735444836318493, -0.016335545107722282, 0.01688990741968155, -0.033761266618967056, 0.04264049977064133, 0.04631083086133003, 0.02169719897210598, -0.04851929098367691, -0.013991246931254864, 0.00497107720002532, -0.005431650672107935, -0.0362272746860981, 0.05647943541407585, -0.015298440121114254, -0.034415390342473984, 0.03377275541424751, -0.01865268684923649, -0.009368045255541801, 0.015067036263644695, -0.030891407281160355, -0.008006050251424313, 0.033124275505542755, -0.009881931357085705, 0.04288339987397194, 0.01924636773765087, -0.006050598807632923, -0.018803521990776062, -0.019003525376319885, 0.013426613993942738, 0.03605456277728081, -0.05365649238228798, -0.005519316531717777, -0.006783206481486559, 0.0026525557041168213, -0.05311882495880127, 0.04736858978867531, -0.0016970753204077482, 0.001934107975102961, 0.051551878452301025, -0.04104161262512207, 0.07395049184560776, 0.04401092231273651, -0.06720613688230515, 0.03285341337323189, -0.018174340948462486, -0.005622238852083683, -0.04556065425276756, 0.0057927933521568775, -0.014021002687513828, -0.0026032484602183104, -0.0023746828082948923, 0.060974232852458954, -0.018062714487314224, -0.06138359382748604, -0.041578225791454315, -0.0369153767824173, -0.04500392824411392, -0.0424857996404171, -0.040178216993808746, -0.039900604635477066, 0.03514230623841286, 0.06981360167264938, -0.03095368854701519, -0.029267625883221626, -0.02522103674709797, -0.029793713241815567, -0.09786618500947952, 0.0005453719641081989, 0.04148014634847641, -0.01694026216864586, -0.029290150851011276, 0.007603716105222702, 0.01598203182220459, 0.06225724518299103, -0.05771584063768387, 0.01989571750164032, 0.022766467183828354, 0.03523332253098488, 0.07264762371778488, -0.023715097457170486, 0.029023749753832817, -0.007673482410609722, -0.007403539028018713, 0.005585302598774433, 0.017769986763596535, 0.05035913363099098, -0.015977585688233376, 0.0016169166192412376, 0.006704722996801138, 0.010514985769987106, 0.05035238340497017, -0.02595062367618084, 0.054979898035526276, -0.046891532838344574, -0.004900076426565647, 0.013661564327776432, 0.0348597951233387, 0.03140031918883324, 0.01243430282920599, -0.07501792907714844, -0.024661434814333916, 0.006259022746235132, -1.5248648196575232e-05, 0.04560282081365585, 0.004388540983200073, 0.001906555495224893, -0.0032416072208434343, 0.013940363191068172, 0.00021165529324207455, -0.011592516675591469, -0.016764340922236443, 0.06575356423854828, 0.027993930503726006, -0.01378786750137806, 0.10314197838306427, -0.07857607305049896, 0.022680073976516724, -0.0009134305291809142, -0.06318723410367966, 0.04860355332493782, -0.03365032747387886, 0.025440065190196037, -0.03338564559817314, -0.052476998418569565, -0.01195498276501894, -0.03008648194372654, 0.0033432694617658854, -0.008713409304618835, -0.012708612717688084, 0.0004020403139293194, 0.006438801530748606, -0.02223292365670204, 0.05535142868757248, -0.00854291208088398, -0.014777948148548603, 0.01598581299185753, -0.034732360392808914, 0.0032275253906846046, -0.05101357772946358, 0.022697696462273598, -0.027118192985653877, 0.007723767310380936, -0.014824559912085533, -0.058855682611465454, 0.02635515108704567, 0.04334668815135956, 0.044087205082178116, 0.03271558880805969, -0.03411280736327171, 0.017738979309797287, -0.007958486676216125, -0.03788032755255699, 0.0381910502910614, 0.00902782753109932, 0.07206746935844421, 0.10690091550350189, 0.007424680050462484, 0.0034760432317852974, -0.04496564343571663, 0.02307494729757309, 0.001119546708650887, 0.040706176310777664, 0.004877446684986353, 0.03030811995267868, 0.016987493261694908, -0.05734526365995407, -0.010726270265877247, -0.005477061495184898, -0.012338162399828434, 0.023079147562384605, 0.0064426325261592865, 0.018023379147052765, -0.001993125304579735, -0.008092779666185379, -0.015819448977708817, -0.0075140465050935745, -0.03286697715520859, -0.07405697554349899, -0.06511569023132324, -0.019981998950242996, -0.01954357698559761, 0.013753838837146759, -0.001383290276862681, -0.006645305082201958, -0.006058522500097752, 0.002293882193043828, -0.010358012281358242, -0.03823737055063248, -0.07208363711833954, -0.008938144892454147, 0.00370803102850914, 0.0038766139186918736, 0.011266392655670643, 0.03995537385344505, 0.03416783735156059, -0.0012853029184043407, 0.024797631427645683, -0.012933261692523956, -0.01576933078467846, 0.004807986319065094, 0.03506812825798988, -0.03838364779949188, -0.006483998149633408, 0.031055545434355736, -0.04745015501976013, 0.07048466056585312, -0.07552384585142136, -0.027246717363595963, -0.06830760836601257, 0.006882162764668465, -0.024721650406718254, 0.09718921035528183, -0.018559766933321953, -0.0011288024252280593, -0.05657573416829109, -0.0700235366821289, -0.05660148337483406, -0.0096006840467453, 0.01166986022144556, -0.04477355629205704, 0.031496185809373856, -0.011987220495939255, -0.0010876517044380307, -0.028072169050574303, -0.03993160277605057, 0.013473179191350937, -0.04584147408604622, 0.022103281691670418, 0.008277908898890018, -0.0038070273585617542, -0.02037445642054081, 0.06145468354225159, 0.017347680404782295, 0.0017571731004863977, -0.010877756401896477, -0.0019243310671299696, 0.02593608945608139, -0.08312691748142242, -0.03769318386912346, -0.056560784578323364, 0.004632606636732817, -0.034733597189188004, -0.03228360414505005, -0.025840453803539276, -0.0038050375878810883, 0.0016368301585316658, 0.024374380707740784, 0.02495880238711834, 0.019468912854790688, -0.013268990442156792, -0.04536015912890434, 0.01596578024327755, -0.02285446785390377, -0.011797250248491764, -0.03853283077478409, -0.026846427470445633, 0.003130924655124545, -0.027369404211640358, -0.019355200231075287, 0.016502078622579575, 0.03332747891545296, 0.02320251055061817, 0.019672643393278122, -0.010225689969956875, -0.04968182370066643, -0.03328308090567589, 0.008075451478362083, 0.1078341156244278, -0.03717776760458946, 0.004175296984612942, 0.043597057461738586, 0.025012167170643806, 0.015302439220249653, -0.03030376322567463, 0.006578406319022179, 0.040074367076158524, -0.032332926988601685, -0.020150942727923393, -0.0046142758801579475, -0.0298289954662323, 0.022746548056602478, -0.02541794814169407, 0.009955331683158875, -0.025215981528162956, 0.001477666082791984, -0.05946853384375572, -0.030336374416947365, 0.014915398322045803, -0.09102897346019745, 0.010123025625944138, 0.009294637478888035, -0.007701405789703131, 0.008178695105016232, -0.020562512800097466, 0.10714934766292572, -0.01200184877961874, 0.0015624441439285874, -0.010494988411664963, -0.03359261527657509, 0.035020794719457626, 0.026205500587821007, 0.040460992604494095, -0.03342648223042488, 0.05205344781279564, 0.007288525812327862, 0.03261154145002365, -0.01687617041170597, -0.01126027200371027, -0.01774979755282402, 0.0029218143317848444, -0.060196325182914734, 0.02289787121117115, -0.024223385378718376, -0.003515620017424226, -0.05089220777153969, 0.04268461838364601, -0.05291137844324112, 0.031066900119185448, 0.016818968579173088, -0.005864706356078386, -0.03806670382618904, -0.0066114491783082485, 0.045578483492136, 0.026174211874604225, -0.011020909063518047, -0.03436892852187157, -0.03348691761493683, 0.021527811884880066, 0.012673843652009964, 0.0020246717613190413, -0.027109863236546516, 0.07844027876853943, -0.01014805119484663, 0.019576648250222206, -0.006293583195656538, -0.048459574580192566, 0.03393078222870827, 0.04574711620807648, -0.06010681763291359, -0.025319358333945274, 0.098772794008255, -0.03859378397464752, -0.030692782253026962, 0.007690828293561935, 0.02825438603758812, -0.014603069983422756, 0.02322860062122345, -0.009207343682646751, 0.06143806502223015, 0.010934746824204922, -0.021852808073163033, 0.055165521800518036, -0.009216084145009518, -0.06793341785669327, 0.04533921554684639, -0.048080924898386, 0.022592894732952118, 0.033068541437387466, 0.007881698198616505, 0.025631239637732506, 0.00029027427081018686, -0.0561126172542572, -0.007271422538906336, 0.05536895617842674, -0.05115574970841408, 0.042531296610832214, -0.025197723880410194, 0.01924409158527851, -0.0027287292759865522, -0.00507039949297905, -0.02539922297000885, -0.010107752867043018, 0.010699602775275707, -0.0382370762526989, -0.01664213091135025, -0.047089584171772, -0.0034264489077031612, -0.06642859429121017, -0.03939779847860336, 0.03560569882392883, 0.00872077327221632, -0.0036075676325708628, -0.03039364702999592, 0.012129309587180614, 0.012651915661990643, 0.009354270994663239, 0.02331015281379223, 0.05055489391088486, -0.022661900147795677, 0.00035730405943468213, 0.0005964897572994232, 0.03398798033595085, 0.04327566176652908, 0.09301704168319702, 0.03518136963248253, -0.0037557317409664392, -0.009706447832286358, -0.06007389351725578, 0.04447989538311958, -0.04350242018699646, -0.017650479450821877, 0.0252322256565094, -0.02052578330039978, -0.06214236468076706, 0.0038642920553684235, 0.02526138350367546, -0.055183544754981995, 0.0270160473883152, 0.059157948940992355, -0.0441071093082428, -0.07567676156759262, -0.006776335649192333, 0.020049113780260086, -0.044227126985788345, 0.010548129677772522, 0.023682815954089165, -0.002461734227836132, 0.029462721198797226, -0.029845494776964188, -0.033939626067876816, -0.051410745829343796, 0.024170760065317154, 0.036312803626060486, -0.06342709809541702, -0.06829972565174103, -0.040413957089185715, -0.006245889700949192, 0.03030882030725479, -0.054809100925922394, -0.06274536997079849, -0.0717836245894432, -0.0009123955969698727, 0.019963543862104416, -0.0375322587788105, 0.053369540721178055, 0.020385492593050003, 0.036196269094944, 0.0587073415517807, -0.043831467628479004, -0.021353349089622498, -0.003790423506870866, 0.009372469037771225, 0.05401051789522171, -0.02122296765446663, -0.025599410757422447, -0.024540526792407036, 0.003171027172356844, 0.01038559153676033, 0.06397418677806854, -0.007261837832629681, 0.008491010405123234, 0.01663254387676716, -0.014911594800651073, 0.050654370337724686, 0.048769839107990265, -0.015214084647595882, 0.05247296020388603, 0.057401981204748154, -0.0187350045889616, -0.01051677018404007, -0.05700339749455452, -0.0035918855573982, 0.02362808585166931, 0.0001396903971908614, -0.022886553779244423, -0.010406074114143848, 0.016252927482128143, -0.0036992013920098543, 0.005374113097786903, 0.04946666210889816, 0.012444490566849709, 0.004213852807879448, 0.03760175779461861, 0.018171176314353943, -0.0313417948782444, -0.02914964035153389, 0.02928393892943859, -0.008435087278485298, 0.0027362164109945297, 0.024289576336741447, 0.0241272933781147, -0.041493237018585205, 0.044811904430389404, 0.02856479585170746, -0.04145795851945877, -0.00435135280713439, -0.06780983507633209, -0.048362985253334045, 0.013448046520352364, -0.027512947097420692, 0.07810564339160919, 0.020975375548005104, 0.027971159666776657, 0.010725722648203373, -0.04188999906182289, 0.007460396271198988, 0.030039574950933456, 0.0516509972512722, 0.1017109751701355, -0.009912806563079357, 0.018026815727353096, 0.0017937867669388652, -0.07013159245252609, -0.03783680498600006, 0.05119873955845833, 0.02860427275300026, 0.05304737761616707, 0.032938648015260696, -0.0352289192378521, -0.05801593139767647, -0.0216226764023304, -0.04440734162926674, -0.046590279787778854, -0.003245545784011483, 0.011306687258183956, 0.03128962218761444, 0.033387381583452225, 0.04227878898382187, -0.03396473452448845, 0.006454319227486849, -0.01405749935656786, -0.006435456685721874, 0.054593607783317566, -0.05222899839282036, 0.03660539537668228, 0.0009407529723830521, 0.03186017647385597, 0.017225367948412895, 0.05820940062403679, -0.03340610861778259, 0.0026137330569326878, +} + +func getEmbedding(query, projectID string) (string, error) { + ctx := context.Background() + vectorDB, err := firestore.NewClientWithDatabase(ctx, projectID, "embedding-db") + if err != nil { + return "", err + } + defer vectorDB.Close() + + collection := vectorDB.Collection("Histories") + vectorQuery := collection.FindNearest("embedding", + testQuery, + 5, + firestore.DistanceMeasureEuclidean, + nil) + docs, err := vectorQuery.Documents(ctx).GetAll() + if err != nil { + return "", err + } + + output := "" + for _, doc := range docs { + output += doc.Data()["content"].(string) + } + return output, nil +} diff --git a/server/ai/vertex.go b/server/ai/vertex.go index d2ad968..a2cb7db 100644 --- a/server/ai/vertex.go +++ b/server/ai/vertex.go @@ -81,6 +81,8 @@ func Predict(query, modality, projectID string) (response string, templateName s response, err = textPredictGemini(query, projectID, GeminiTuned) case AgentAssisted: response, err = textPredictWithReddit(query, projectID) + case EmbeddingsAssisted: + response, err = textPredictWithEmbeddings(query, projectID) default: response, err = textPredictGemini(query, projectID, Gemini) } @@ -397,3 +399,35 @@ func textPredictWithReddit(query, projectID string) (string, error) { output := string(res.Candidates[0].Content.Parts[0].(genai.Text)) return output, nil } + +func textPredictWithEmbeddings(query, projectID string) (string, error) { + + // Get context from embeddings + embeddingContext, err := getEmbedding(query, projectID) + if err != nil { + return "", err + } + + ctx := context.Background() + client, err := genai.NewClient(ctx, projectID, "us-west1") + if err != nil { + return "", err + } + defer client.Close() + + llm := client.GenerativeModel(GeminiModel) + + createPrompt(query, GeminiTemplate, embeddingContext) + + resp, err := llm.GenerateContent(ctx, genai.Text(query)) + if err != nil { + fmt.Println(err.Error()) + return "", err + } + + candidate, err := getCandidate(resp) + if err != nil { + return "I'm not sure how to answer that. Would you please repeat the question?", nil + } + return extractAnswer(candidate), nil +} diff --git a/site/html/index.html b/site/html/index.html index 80951ab..89cd060 100644 --- a/site/html/index.html +++ b/site/html/index.html @@ -35,6 +35,7 @@ + diff --git a/site/js/index.js b/site/js/index.js index bcaf2a4..c812ab7 100644 --- a/site/js/index.js +++ b/site/js/index.js @@ -86,6 +86,10 @@ function processForm(e) { // Collect data const message = document.getElementById("userMsg").value; + if (message === "") { + return; + } + const selection = document.getElementById("modelSelect"); const model = selection.options[selection.selectedIndex].text; From 1d6f09ff82f73ee287d6ecd8a802a0c986de51a0 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Fri, 6 Dec 2024 13:55:38 -0800 Subject: [PATCH 4/6] feat: integrated embeddings mode --- server/ai/vertex.go | 81 ++++++++++++++++++++++---- server/{ => databases}/db.go | 14 ++--- server/{ => databases}/db_test.go | 8 +-- server/{ai => databases}/embeddings.go | 24 +++++--- server/main.go | 12 ++-- site/js/index.js | 12 ++-- 6 files changed, 109 insertions(+), 42 deletions(-) rename server/{ => databases}/db.go (86%) rename server/{ => databases}/db_test.go (91%) rename server/{ai => databases}/embeddings.go (97%) diff --git a/server/ai/vertex.go b/server/ai/vertex.go index a2cb7db..6004227 100644 --- a/server/ai/vertex.go +++ b/server/ai/vertex.go @@ -18,6 +18,7 @@ import ( "google.golang.org/api/option" "google.golang.org/protobuf/types/known/structpb" + db "github.com/telpirion/MyHerodotus/databases" "github.com/telpirion/MyHerodotus/generated" ) @@ -28,7 +29,9 @@ const ( GemmaTemplate = "templates/gemma.2024.10.25.tmpl" GeminiModel = "gemini-1.5-flash-001" HistoryTemplate = "templates/conversation_history.tmpl" + EmbeddingModelName = "text-embedding-005" MaxGemmaTokens int32 = 2048 + location = "us-west1" ) var ( @@ -88,7 +91,7 @@ func Predict(query, modality, projectID string) (response string, templateName s } if err != nil { - return "", "", nil + return "", "", err } cachedContext += fmt.Sprintf("### Human: %s\n### Assistant: %s\n", query, response) @@ -97,7 +100,6 @@ func Predict(query, modality, projectID string) (response string, templateName s // GetTokenCount uses the Gemini tokenizer to count the tokens in some text. func GetTokenCount(text, projectID string) (int32, error) { - location := "us-west1" ctx := context.Background() client, err := genai.NewClient(ctx, projectID, location) if err != nil { @@ -137,7 +139,6 @@ func StoreConversationContext(conversationHistory []generated.ConversationBit, p } ctx := context.Background() - location := "us-west1" client, err := genai.NewClient(ctx, projectID, location) if err != nil { return "", fmt.Errorf("unable to create client: %w", err) @@ -210,7 +211,6 @@ func createPrompt(message, templateName, history string) (string, error) { // textPredictGemma2 generates text using a Gemma2 hosted model func textPredictGemma(message, projectID string) (string, error) { ctx := context.Background() - location := "us-west1" endpointID := os.Getenv("ENDPOINT_ID") gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID) @@ -267,8 +267,6 @@ func textPredictGemma(message, projectID string) (string, error) { // textPredictGemini generates text using a Gemini 1.5 Flash model func textPredictGemini(message, projectID string, modality Modality) (string, error) { ctx := context.Background() - location := "us-west1" - client, err := genai.NewClient(ctx, projectID, location) if err != nil { return "", err @@ -298,7 +296,7 @@ func textPredictGemini(message, projectID string, modality Modality) (string, er candidate, err := getCandidate(resp) if err != nil { - return "I'm not sure how to answer that. Would you please repeat the question?", nil + return "", nil } return extractAnswer(candidate), nil } @@ -402,14 +400,19 @@ func textPredictWithReddit(query, projectID string) (string, error) { func textPredictWithEmbeddings(query, projectID string) (string, error) { + queryEmbed, err := getQueryTextEmbedding(query, projectID) + if err != nil { + return "", err + } + // Get context from embeddings - embeddingContext, err := getEmbedding(query, projectID) + embeddingContext, err := db.GetEmbedding(queryEmbed, projectID) if err != nil { return "", err } ctx := context.Background() - client, err := genai.NewClient(ctx, projectID, "us-west1") + client, err := genai.NewClient(ctx, projectID, location) if err != nil { return "", err } @@ -427,7 +430,65 @@ func textPredictWithEmbeddings(query, projectID string) (string, error) { candidate, err := getCandidate(resp) if err != nil { - return "I'm not sure how to answer that. Would you please repeat the question?", nil + return "", err } return extractAnswer(candidate), nil } + +func getQueryTextEmbedding(query, projectID string) ([]float32, error) { + + var embedding []float32 + ctx := context.Background() + + apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) + dimensionality := 128 + texts := []string{query} + + client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) + if err != nil { + return embedding, err + } + defer client.Close() + + endpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", + projectID, location, EmbeddingModelName) + instances := make([]*structpb.Value, len(texts)) + for i, text := range texts { + instances[i] = structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "content": structpb.NewStringValue(text), + "task_type": structpb.NewStringValue("RETRIEVAL_QUERY"), + }, + }) + } + + params := structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "outputDimensionality": structpb.NewNumberValue(float64(dimensionality)), + }, + }) + + req := &aiplatformpb.PredictRequest{ + Endpoint: endpoint, + Instances: instances, + Parameters: params, + } + resp, err := client.Predict(ctx, req) + if err != nil { + return embedding, err + } + embeddings := make([][]float32, len(resp.Predictions)) + for i, prediction := range resp.Predictions { + values := prediction.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values + embeddings[i] = make([]float32, len(values)) + for j, value := range values { + embeddings[i][j] = float32(value.GetNumberValue()) + } + } + + if len(embeddings) == 0 { + return embedding, fmt.Errorf("vertex: text embeddings: no embeddings created") + } + + return embeddings[0], nil +} diff --git a/server/db.go b/server/databases/db.go similarity index 86% rename from server/db.go rename to server/databases/db.go index d1c10d3..9ecb589 100644 --- a/server/db.go +++ b/server/databases/db.go @@ -25,11 +25,10 @@ Herodotus [ ] */ -package main +package databases import ( "context" - "fmt" "os" "cloud.google.com/go/firestore" @@ -60,7 +59,7 @@ type ConversationHistory struct { Conversations []generated.ConversationBit } -func saveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) { +func SaveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) { ctx := context.Background() // Get CollectionName for running in staging or prod @@ -71,7 +70,6 @@ func saveConversation(convo generated.ConversationBit, userEmail, projectID stri client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName) if err != nil { - LogError(fmt.Sprintf("firestore.Client: %v\n", err)) return "", err } defer client.Close() @@ -83,7 +81,7 @@ func saveConversation(convo generated.ConversationBit, userEmail, projectID stri return docRef.ID, err } -func updateConversation(documentId, userEmail, rating, projectID string) error { +func UpdateConversation(documentId, userEmail, rating, projectID string) error { // Get CollectionName for running in staging or prod _collectionName, ok := os.LookupEnv("COLLECTION_NAME") @@ -106,12 +104,11 @@ func updateConversation(documentId, userEmail, rating, projectID string) error { return nil } -func getConversation(userEmail, projectID string) ([]generated.ConversationBit, error) { +func GetConversation(userEmail, projectID string) ([]generated.ConversationBit, error) { ctx := context.Background() conversations := []generated.ConversationBit{} client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName) if err != nil { - LogError(fmt.Sprintf("firestore.Client: %v\n", err)) return conversations, err } defer client.Close() @@ -125,13 +122,12 @@ func getConversation(userEmail, projectID string) ([]generated.ConversationBit, break } if err != nil { - LogError(fmt.Sprintf("Firestore Iterator: %v\n", err)) return conversations, err } var convo generated.ConversationBit err = doc.DataTo(&convo) if err != nil { - LogError(fmt.Sprintf("Firestore document unmarshaling: %v\n", err)) + continue } conversations = append(conversations, convo) diff --git a/server/db_test.go b/server/databases/db_test.go similarity index 91% rename from server/db_test.go rename to server/databases/db_test.go index d342a67..38a81e2 100644 --- a/server/db_test.go +++ b/server/databases/db_test.go @@ -1,4 +1,4 @@ -package main +package databases import ( "context" @@ -67,7 +67,7 @@ func TestSaveConversation(t *testing.T) { BotResponse: "This is a bot response", Created: time.Now().Unix(), } - id, err := saveConversation(*convo, email, _projectID) + id, err := SaveConversation(*convo, email, _projectID) if err != nil { t.Fatal(err) } @@ -81,7 +81,7 @@ func TestSaveConversation(t *testing.T) { BotResponse: "This is another fake bot response", Created: time.Now().Unix(), } - nextID, err := saveConversation(*nextConvo, email, _projectID) + nextID, err := SaveConversation(*nextConvo, email, _projectID) if err != nil { t.Fatalf("Error on adding next conversation: %v\n\n", err) } @@ -91,7 +91,7 @@ func TestSaveConversation(t *testing.T) { } func TestGetConversation(t *testing.T) { - conversations, err := getConversation(email2, projectID) + conversations, err := GetConversation(email2, _projectID) if err != nil { t.Fatal(err) } diff --git a/server/ai/embeddings.go b/server/databases/embeddings.go similarity index 97% rename from server/ai/embeddings.go rename to server/databases/embeddings.go index 7e28f72..5cd074b 100644 --- a/server/ai/embeddings.go +++ b/server/databases/embeddings.go @@ -1,4 +1,4 @@ -package ai +package databases import ( "context" @@ -10,18 +10,26 @@ var testQuery = []float32{ 0.02855090983211994, -0.035435665398836136, -0.030386969447135925, -0.02779807522892952, 0.06222547963261604, 0.032307568937540054, 0.006095572374761105, -0.017693594098091125, 0.03338219225406647, 0.052011892199516296, -0.0026677092537283897, -0.024139121174812317, 0.03690512105822563, -0.01943613775074482, -0.012934908270835876, -0.010278051719069481, 0.012921503745019436, 0.01350736990571022, 0.03510843962430954, -0.05104907602071762, -0.015654291957616806, 0.009973280131816864, 0.03959834948182106, 0.0084367161616683, 0.002778309164568782, -0.021186387166380882, -0.009058095514774323, -0.06550347059965134, -0.00020016275811940432, 0.03002917394042015, -0.0797576755285263, 0.01018990483134985, -0.07638117671012878, -0.004428969230502844, 0.02399042248725891, -0.027329476550221443, 0.01227719709277153, -0.009645448997616768, -0.022647777572274208, 0.02372140996158123, 0.00889495573937893, 0.044717513024806976, -0.0004786151403095573, 0.013710719533264637, 0.013255547732114792, -0.013190221041440964, -0.03924371674656868, -0.026655839756131172, -0.03915565460920334, -0.03143079951405525, 0.00671390863135457, -0.00036247173557057977, -0.019021255895495415, -0.007367352023720741, 0.023029474541544914, -0.05825888738036156, 0.07291541993618011, 0.01019354723393917, -0.03974568843841553, 0.028182122856378555, 0.02223936840891838, 0.025718091055750847, -0.009822058491408825, 0.0294659286737442, 0.045444194227457047, -0.05296202003955841, -0.03392653167247772, -0.028148580342531204, 0.08328374475240707, -0.005129489582031965, 0.04430880770087242, -0.005098880268633366, 0.026339663192629814, -0.002592273522168398, -0.03713815659284592, -0.05650784820318222, -0.00985332578420639, 0.06560691446065903, 0.04158855602145195, 0.0029465670231729746, 0.001012718304991722, -0.06649360060691833, -0.07900629937648773, -0.08239595592021942, -0.040955450385808945, -0.005131527315825224, -0.04057120159268379, -0.007040790747851133, -0.01463773287832737, 0.09051593393087387, 0.00812543649226427, 0.02343042939901352, 0.01811213046312332, -0.09880316257476807, -0.003164332127198577, 0.04705613851547241, 0.020249972119927406, -0.004456446506083012, -0.06126254051923752, -0.015342309139668941, -0.03331226855516434, -0.04447592794895172, -0.03512626513838768, -0.013713927939534187, 0.04527364671230316, 0.003716537728905678, 0.024795016273856163, 0.04758043587207794, 0.01533534936606884, 0.055674780160188675, -0.10856613516807556, 0.015341871418058872, 0.02090427279472351, -0.013197915628552437, -0.016130618751049042, -0.006240378133952618, -0.029776113107800484, 0.06432545930147171, -0.019419798627495766, -0.05422724038362503, 0.0129588283598423, -0.02371649071574211, 0.030254937708377838, 0.003299255622550845, 0.007495594676584005, 0.054419200867414474, 0.006257002707570791, 0.00789280142635107, 0.04340459406375885, 0.05346832424402237, -0.029386304318904877, -0.028848648071289062, 0.03731193393468857, 0.037310048937797546, 0.05145693197846413, 0.05570625886321068, 0.037225741893053055, -0.00999367143958807, 0.03503532335162163, 0.025289271026849747, 0.06072063371539116, 0.04446175694465637, -0.025419248268008232, -0.0026832991279661655, -0.01518651470541954, -0.013102266006171703, -0.03348962217569351, -0.007308039348572493, 0.025298891589045525, -0.020990043878555298, -0.0009921509772539139, -0.0012421614956110716, -0.021900959312915802, -0.02093563787639141, 0.12074049562215805, -0.0015050763031467795, 0.023123204708099365, 0.00012658465129788965, 0.01790185086429119, -0.03898821026086807, 0.048819027841091156, 0.02707771584391594, -0.02640371583402157, 0.020628876984119415, 0.020295720547437668, -0.04107430949807167, 0.051317933946847916, -0.035174526274204254, -0.01880166307091713, -0.005285138729959726, -0.013532846234738827, -0.041030798107385635, -0.05447480082511902, -0.006849688943475485, -0.06334996968507767, -0.021188974380493164, -0.049942947924137115, -0.05172604322433472, -0.02559443935751915, 0.01254394929856062, -0.04415528103709221, 0.00871290359646082, 0.00595787912607193, -0.02759362757205963, 0.07408977299928665, 0.022976472973823547, 0.0688139796257019, -0.06209009885787964, 0.0003142820205539465, -0.07047729939222336, 0.02597867324948311, 0.01738431304693222, -0.0272830743342638, 0.0016810598317533731, -0.015276334248483181, -0.00023947120644152164, 0.021162616088986397, 0.013781429268419743, -0.008525359444320202, -0.0742156133055687, -0.033778589218854904, 0.055541202425956726, -0.005239429883658886, -0.022989504039287567, 0.025311576202511787, 0.008473692461848259, 0.03572787344455719, -0.028337523341178894, -0.002087587956339121, 0.028023691847920418, -0.04168621823191643, 0.02454696223139763, -0.07450693845748901, 0.00983236264437437, -0.001956229330971837, -0.013308721594512463, -0.019844714552164078, 0.022057563066482544, 0.022406132891774178, -0.03200545161962509, -0.0007785293273627758, 0.009939076378941536, -0.034312572330236435, 0.02353755757212639, -0.0037626009434461594, 0.019475946202874184, 0.0039074718952178955, 0.014623363502323627, 0.006834374740719795, -0.08410128951072693, 0.018348829820752144, 0.031841207295656204, 0.06144539266824722, -0.043109212070703506, 0.02122630923986435, -0.03643026947975159, -0.012863168492913246, -0.005735444836318493, -0.016335545107722282, 0.01688990741968155, -0.033761266618967056, 0.04264049977064133, 0.04631083086133003, 0.02169719897210598, -0.04851929098367691, -0.013991246931254864, 0.00497107720002532, -0.005431650672107935, -0.0362272746860981, 0.05647943541407585, -0.015298440121114254, -0.034415390342473984, 0.03377275541424751, -0.01865268684923649, -0.009368045255541801, 0.015067036263644695, -0.030891407281160355, -0.008006050251424313, 0.033124275505542755, -0.009881931357085705, 0.04288339987397194, 0.01924636773765087, -0.006050598807632923, -0.018803521990776062, -0.019003525376319885, 0.013426613993942738, 0.03605456277728081, -0.05365649238228798, -0.005519316531717777, -0.006783206481486559, 0.0026525557041168213, -0.05311882495880127, 0.04736858978867531, -0.0016970753204077482, 0.001934107975102961, 0.051551878452301025, -0.04104161262512207, 0.07395049184560776, 0.04401092231273651, -0.06720613688230515, 0.03285341337323189, -0.018174340948462486, -0.005622238852083683, -0.04556065425276756, 0.0057927933521568775, -0.014021002687513828, -0.0026032484602183104, -0.0023746828082948923, 0.060974232852458954, -0.018062714487314224, -0.06138359382748604, -0.041578225791454315, -0.0369153767824173, -0.04500392824411392, -0.0424857996404171, -0.040178216993808746, -0.039900604635477066, 0.03514230623841286, 0.06981360167264938, -0.03095368854701519, -0.029267625883221626, -0.02522103674709797, -0.029793713241815567, -0.09786618500947952, 0.0005453719641081989, 0.04148014634847641, -0.01694026216864586, -0.029290150851011276, 0.007603716105222702, 0.01598203182220459, 0.06225724518299103, -0.05771584063768387, 0.01989571750164032, 0.022766467183828354, 0.03523332253098488, 0.07264762371778488, -0.023715097457170486, 0.029023749753832817, -0.007673482410609722, -0.007403539028018713, 0.005585302598774433, 0.017769986763596535, 0.05035913363099098, -0.015977585688233376, 0.0016169166192412376, 0.006704722996801138, 0.010514985769987106, 0.05035238340497017, -0.02595062367618084, 0.054979898035526276, -0.046891532838344574, -0.004900076426565647, 0.013661564327776432, 0.0348597951233387, 0.03140031918883324, 0.01243430282920599, -0.07501792907714844, -0.024661434814333916, 0.006259022746235132, -1.5248648196575232e-05, 0.04560282081365585, 0.004388540983200073, 0.001906555495224893, -0.0032416072208434343, 0.013940363191068172, 0.00021165529324207455, -0.011592516675591469, -0.016764340922236443, 0.06575356423854828, 0.027993930503726006, -0.01378786750137806, 0.10314197838306427, -0.07857607305049896, 0.022680073976516724, -0.0009134305291809142, -0.06318723410367966, 0.04860355332493782, -0.03365032747387886, 0.025440065190196037, -0.03338564559817314, -0.052476998418569565, -0.01195498276501894, -0.03008648194372654, 0.0033432694617658854, -0.008713409304618835, -0.012708612717688084, 0.0004020403139293194, 0.006438801530748606, -0.02223292365670204, 0.05535142868757248, -0.00854291208088398, -0.014777948148548603, 0.01598581299185753, -0.034732360392808914, 0.0032275253906846046, -0.05101357772946358, 0.022697696462273598, -0.027118192985653877, 0.007723767310380936, -0.014824559912085533, -0.058855682611465454, 0.02635515108704567, 0.04334668815135956, 0.044087205082178116, 0.03271558880805969, -0.03411280736327171, 0.017738979309797287, -0.007958486676216125, -0.03788032755255699, 0.0381910502910614, 0.00902782753109932, 0.07206746935844421, 0.10690091550350189, 0.007424680050462484, 0.0034760432317852974, -0.04496564343571663, 0.02307494729757309, 0.001119546708650887, 0.040706176310777664, 0.004877446684986353, 0.03030811995267868, 0.016987493261694908, -0.05734526365995407, -0.010726270265877247, -0.005477061495184898, -0.012338162399828434, 0.023079147562384605, 0.0064426325261592865, 0.018023379147052765, -0.001993125304579735, -0.008092779666185379, -0.015819448977708817, -0.0075140465050935745, -0.03286697715520859, -0.07405697554349899, -0.06511569023132324, -0.019981998950242996, -0.01954357698559761, 0.013753838837146759, -0.001383290276862681, -0.006645305082201958, -0.006058522500097752, 0.002293882193043828, -0.010358012281358242, -0.03823737055063248, -0.07208363711833954, -0.008938144892454147, 0.00370803102850914, 0.0038766139186918736, 0.011266392655670643, 0.03995537385344505, 0.03416783735156059, -0.0012853029184043407, 0.024797631427645683, -0.012933261692523956, -0.01576933078467846, 0.004807986319065094, 0.03506812825798988, -0.03838364779949188, -0.006483998149633408, 0.031055545434355736, -0.04745015501976013, 0.07048466056585312, -0.07552384585142136, -0.027246717363595963, -0.06830760836601257, 0.006882162764668465, -0.024721650406718254, 0.09718921035528183, -0.018559766933321953, -0.0011288024252280593, -0.05657573416829109, -0.0700235366821289, -0.05660148337483406, -0.0096006840467453, 0.01166986022144556, -0.04477355629205704, 0.031496185809373856, -0.011987220495939255, -0.0010876517044380307, -0.028072169050574303, -0.03993160277605057, 0.013473179191350937, -0.04584147408604622, 0.022103281691670418, 0.008277908898890018, -0.0038070273585617542, -0.02037445642054081, 0.06145468354225159, 0.017347680404782295, 0.0017571731004863977, -0.010877756401896477, -0.0019243310671299696, 0.02593608945608139, -0.08312691748142242, -0.03769318386912346, -0.056560784578323364, 0.004632606636732817, -0.034733597189188004, -0.03228360414505005, -0.025840453803539276, -0.0038050375878810883, 0.0016368301585316658, 0.024374380707740784, 0.02495880238711834, 0.019468912854790688, -0.013268990442156792, -0.04536015912890434, 0.01596578024327755, -0.02285446785390377, -0.011797250248491764, -0.03853283077478409, -0.026846427470445633, 0.003130924655124545, -0.027369404211640358, -0.019355200231075287, 0.016502078622579575, 0.03332747891545296, 0.02320251055061817, 0.019672643393278122, -0.010225689969956875, -0.04968182370066643, -0.03328308090567589, 0.008075451478362083, 0.1078341156244278, -0.03717776760458946, 0.004175296984612942, 0.043597057461738586, 0.025012167170643806, 0.015302439220249653, -0.03030376322567463, 0.006578406319022179, 0.040074367076158524, -0.032332926988601685, -0.020150942727923393, -0.0046142758801579475, -0.0298289954662323, 0.022746548056602478, -0.02541794814169407, 0.009955331683158875, -0.025215981528162956, 0.001477666082791984, -0.05946853384375572, -0.030336374416947365, 0.014915398322045803, -0.09102897346019745, 0.010123025625944138, 0.009294637478888035, -0.007701405789703131, 0.008178695105016232, -0.020562512800097466, 0.10714934766292572, -0.01200184877961874, 0.0015624441439285874, -0.010494988411664963, -0.03359261527657509, 0.035020794719457626, 0.026205500587821007, 0.040460992604494095, -0.03342648223042488, 0.05205344781279564, 0.007288525812327862, 0.03261154145002365, -0.01687617041170597, -0.01126027200371027, -0.01774979755282402, 0.0029218143317848444, -0.060196325182914734, 0.02289787121117115, -0.024223385378718376, -0.003515620017424226, -0.05089220777153969, 0.04268461838364601, -0.05291137844324112, 0.031066900119185448, 0.016818968579173088, -0.005864706356078386, -0.03806670382618904, -0.0066114491783082485, 0.045578483492136, 0.026174211874604225, -0.011020909063518047, -0.03436892852187157, -0.03348691761493683, 0.021527811884880066, 0.012673843652009964, 0.0020246717613190413, -0.027109863236546516, 0.07844027876853943, -0.01014805119484663, 0.019576648250222206, -0.006293583195656538, -0.048459574580192566, 0.03393078222870827, 0.04574711620807648, -0.06010681763291359, -0.025319358333945274, 0.098772794008255, -0.03859378397464752, -0.030692782253026962, 0.007690828293561935, 0.02825438603758812, -0.014603069983422756, 0.02322860062122345, -0.009207343682646751, 0.06143806502223015, 0.010934746824204922, -0.021852808073163033, 0.055165521800518036, -0.009216084145009518, -0.06793341785669327, 0.04533921554684639, -0.048080924898386, 0.022592894732952118, 0.033068541437387466, 0.007881698198616505, 0.025631239637732506, 0.00029027427081018686, -0.0561126172542572, -0.007271422538906336, 0.05536895617842674, -0.05115574970841408, 0.042531296610832214, -0.025197723880410194, 0.01924409158527851, -0.0027287292759865522, -0.00507039949297905, -0.02539922297000885, -0.010107752867043018, 0.010699602775275707, -0.0382370762526989, -0.01664213091135025, -0.047089584171772, -0.0034264489077031612, -0.06642859429121017, -0.03939779847860336, 0.03560569882392883, 0.00872077327221632, -0.0036075676325708628, -0.03039364702999592, 0.012129309587180614, 0.012651915661990643, 0.009354270994663239, 0.02331015281379223, 0.05055489391088486, -0.022661900147795677, 0.00035730405943468213, 0.0005964897572994232, 0.03398798033595085, 0.04327566176652908, 0.09301704168319702, 0.03518136963248253, -0.0037557317409664392, -0.009706447832286358, -0.06007389351725578, 0.04447989538311958, -0.04350242018699646, -0.017650479450821877, 0.0252322256565094, -0.02052578330039978, -0.06214236468076706, 0.0038642920553684235, 0.02526138350367546, -0.055183544754981995, 0.0270160473883152, 0.059157948940992355, -0.0441071093082428, -0.07567676156759262, -0.006776335649192333, 0.020049113780260086, -0.044227126985788345, 0.010548129677772522, 0.023682815954089165, -0.002461734227836132, 0.029462721198797226, -0.029845494776964188, -0.033939626067876816, -0.051410745829343796, 0.024170760065317154, 0.036312803626060486, -0.06342709809541702, -0.06829972565174103, -0.040413957089185715, -0.006245889700949192, 0.03030882030725479, -0.054809100925922394, -0.06274536997079849, -0.0717836245894432, -0.0009123955969698727, 0.019963543862104416, -0.0375322587788105, 0.053369540721178055, 0.020385492593050003, 0.036196269094944, 0.0587073415517807, -0.043831467628479004, -0.021353349089622498, -0.003790423506870866, 0.009372469037771225, 0.05401051789522171, -0.02122296765446663, -0.025599410757422447, -0.024540526792407036, 0.003171027172356844, 0.01038559153676033, 0.06397418677806854, -0.007261837832629681, 0.008491010405123234, 0.01663254387676716, -0.014911594800651073, 0.050654370337724686, 0.048769839107990265, -0.015214084647595882, 0.05247296020388603, 0.057401981204748154, -0.0187350045889616, -0.01051677018404007, -0.05700339749455452, -0.0035918855573982, 0.02362808585166931, 0.0001396903971908614, -0.022886553779244423, -0.010406074114143848, 0.016252927482128143, -0.0036992013920098543, 0.005374113097786903, 0.04946666210889816, 0.012444490566849709, 0.004213852807879448, 0.03760175779461861, 0.018171176314353943, -0.0313417948782444, -0.02914964035153389, 0.02928393892943859, -0.008435087278485298, 0.0027362164109945297, 0.024289576336741447, 0.0241272933781147, -0.041493237018585205, 0.044811904430389404, 0.02856479585170746, -0.04145795851945877, -0.00435135280713439, -0.06780983507633209, -0.048362985253334045, 0.013448046520352364, -0.027512947097420692, 0.07810564339160919, 0.020975375548005104, 0.027971159666776657, 0.010725722648203373, -0.04188999906182289, 0.007460396271198988, 0.030039574950933456, 0.0516509972512722, 0.1017109751701355, -0.009912806563079357, 0.018026815727353096, 0.0017937867669388652, -0.07013159245252609, -0.03783680498600006, 0.05119873955845833, 0.02860427275300026, 0.05304737761616707, 0.032938648015260696, -0.0352289192378521, -0.05801593139767647, -0.0216226764023304, -0.04440734162926674, -0.046590279787778854, -0.003245545784011483, 0.011306687258183956, 0.03128962218761444, 0.033387381583452225, 0.04227878898382187, -0.03396473452448845, 0.006454319227486849, -0.01405749935656786, -0.006435456685721874, 0.054593607783317566, -0.05222899839282036, 0.03660539537668228, 0.0009407529723830521, 0.03186017647385597, 0.017225367948412895, 0.05820940062403679, -0.03340610861778259, 0.0026137330569326878, } -func getEmbedding(query, projectID string) (string, error) { +const ( + EmbeddingDBName = "embeddings-pdf" + EmbeddingCollection = "Greece" + EmbeddingField = "embedding" + LimitNN = 5 + ContentField = "content" +) + +func GetEmbedding(query []float32, projectID string) (string, error) { ctx := context.Background() - vectorDB, err := firestore.NewClientWithDatabase(ctx, projectID, "embedding-db") + vectorDB, err := firestore.NewClientWithDatabase(ctx, projectID, EmbeddingDBName) if err != nil { return "", err } defer vectorDB.Close() - collection := vectorDB.Collection("Histories") - vectorQuery := collection.FindNearest("embedding", - testQuery, - 5, + collection := vectorDB.Collection(EmbeddingCollection) + vectorQuery := collection.FindNearest(EmbeddingField, + query, + LimitNN, firestore.DistanceMeasureEuclidean, nil) docs, err := vectorQuery.Documents(ctx).GetAll() @@ -31,7 +39,7 @@ func getEmbedding(query, projectID string) (string, error) { output := "" for _, doc := range docs { - output += doc.Data()["content"].(string) + output += doc.Data()[ContentField].(string) } return output, nil } diff --git a/server/main.go b/server/main.go index 7ceef78..73ca681 100644 --- a/server/main.go +++ b/server/main.go @@ -9,6 +9,7 @@ import ( "time" ai "github.com/telpirion/MyHerodotus/ai" + db "github.com/telpirion/MyHerodotus/databases" "github.com/telpirion/MyHerodotus/generated" "github.com/gin-gonic/gin" @@ -94,7 +95,7 @@ func startConversation(c *gin.Context) { LogInfo("Start conversation request received") // create a new conversation context - convoHistory, err := getConversation(encryptedEmail, projectID) + convoHistory, err := db.GetConversation(encryptedEmail, projectID) if err != nil { LogError(fmt.Sprintf("couldn't get conversation history: %v\n", err)) } @@ -150,7 +151,10 @@ func respondToUser(c *gin.Context) { botResponse, promptTemplateName, err = ai.Predict(userMsg.Message, userMsg.Model, projectID) if err != nil { LogError(fmt.Sprintf("bad response from %s: %v\n", userMsg.Model, err)) - botResponse = "Oops! I had troubles understanding that ..." + c.JSON(http.StatusOK, gin.H{ + "Message": "Oops! I had troubles understanding that ...", + }) + return } // Store data in Firestore @@ -208,7 +212,7 @@ func updateDatabase(projectID, userMessage, modelName, promptTemplateName, botRe // Store the conversation in Firestore and update the cachedContext // This is dual-entry accounting so that we don't have to query Firestore // every time to update the cached context - documentID, err := saveConversation(*convo, encryptedEmail, projectID) + documentID, err := db.SaveConversation(*convo, encryptedEmail, projectID) if err != nil { return "", fmt.Errorf("couldn't save conversation: %v", err) } @@ -251,7 +255,7 @@ func rateResponse(c *gin.Context) { return } - err = updateConversation(userRating.DocumentID, encryptedEmail, userRating.UserRating, projectID) + err = db.UpdateConversation(userRating.DocumentID, encryptedEmail, userRating.UserRating, projectID) if err != nil { LogError(err.Error()) c.JSON(http.StatusBadRequest, gin.H{ diff --git a/site/js/index.js b/site/js/index.js index c812ab7..2d6df44 100644 --- a/site/js/index.js +++ b/site/js/index.js @@ -69,7 +69,11 @@ function processRating(e) { } function processForm(e) { - //e.preventDefault(); + // Collect data and return early if there is no message from user. + const message = document.getElementById("userMsg").value; + if (message === "") { + return; + } // Emit 'msg' event for bot progress bar const event = new Event("msg"); @@ -84,12 +88,6 @@ function processForm(e) { window.location = `/?status=unauthorized`; } - // Collect data - const message = document.getElementById("userMsg").value; - if (message === "") { - return; - } - const selection = document.getElementById("modelSelect"); const model = selection.options[selection.selectedIndex].text; From 08dea6b9762f76cc31350fa50b2360dc375344c0 Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Fri, 6 Dec 2024 14:00:04 -0800 Subject: [PATCH 5/6] fix Dockerfile --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index f179bbe..96093b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,6 +18,7 @@ COPY prompts ./server/templates COPY server/favicon.ico ./server/favicon.ico COPY server/generated ./server/generated COPY server/ai ./server/ai +COPY server/databases ./server/databases COPY server/*.go ./server COPY server/go.mod server/go.sum ./server/ From 2a948b421e374ba06886f1f157df2560072845ea Mon Sep 17 00:00:00 2001 From: Eric Schmidt Date: Fri, 6 Dec 2024 14:02:39 -0800 Subject: [PATCH 6/6] Added embeddings to eval list --- services/evaluations/evaluations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/services/evaluations/evaluations.py b/services/evaluations/evaluations.py index eb88e3c..f1ccc09 100644 --- a/services/evaluations/evaluations.py +++ b/services/evaluations/evaluations.py @@ -64,7 +64,8 @@ def main(): ("gemini", "gemini_1_5_flash_001"), ("gemini-tuned", "tuned_gemini"), ("gemma", "gemma"), - ("agent-assisted", "agent-assisted") + ("agent-assisted", "agent-assisted"), + ("embeddings-assisted", "embeddings-assisted") ] for m in models: model_id, model_name = m