
Commit f613a91

embedding that is compatible with openai
1 parent 0f66207 commit f613a91

File tree

1 file changed: +45 -5 lines changed

controllers/llamaCPP.cc

Lines changed: 45 additions & 5 deletions
@@ -27,6 +27,45 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+#include <ctime>
+#include <json/json.h>
+#include <string>
+#include <vector>
+
+std::string create_embedding_payload(const std::vector<float> &embedding,
+                                     int prompt_tokens) {
+  Json::Value root;
+
+  root["object"] = "list";
+
+  Json::Value dataArray(Json::arrayValue);
+  Json::Value dataItem;
+
+  dataItem["object"] = "embedding";
+
+  Json::Value embeddingArray(Json::arrayValue);
+  for (const auto &value : embedding) {
+    embeddingArray.append(value);
+  }
+  dataItem["embedding"] = embeddingArray;
+  dataItem["index"] = 0;
+
+  dataArray.append(dataItem);
+  root["data"] = dataArray;
+
+  root["model"] = "_";
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
+                                         // tokens in this context
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_full_return_json(const std::string &id,
                                     const std::string &model,
                                     const std::string &content,
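For context, the payload built above matches the shape of OpenAI's /v1/embeddings response: a "list" object carrying data, model, and usage fields. A minimal usage sketch, assuming jsoncpp is linked and the helper from this diff is compiled into the same target (the embedding values are illustrative, not real model output):

#include <iostream>
#include <string>
#include <vector>

// Defined in controllers/llamaCPP.cc by this commit.
std::string create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens);

int main() {
  // Illustrative values only; a real embedding comes from llama.cpp.
  std::vector<float> fake_embedding = {0.5f, 0.25f, -1.0f};
  std::cout << create_embedding_payload(fake_embedding, 3) << "\n";
  // Prints a compact, single-line object; jsoncpp orders keys alphabetically,
  // and exact numeric formatting may vary by jsoncpp version:
  // {"data":[{"embedding":[0.5,0.25,-1.0],"index":0,"object":"embedding"}],
  //  "model":"_","object":"list","usage":{"prompt_tokens":3,"total_tokens":3}}
  return 0;
}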
@@ -245,17 +284,18 @@ void llamaCPP::embedding(
   const auto &jsonBody = req->getJsonObject();
 
   json prompt;
-  if (jsonBody->isMember("content") != 0) {
-    prompt = (*jsonBody)["content"].asString();
+  if (jsonBody->isMember("input") != 0) {
+    prompt = (*jsonBody)["input"].asString();
   } else {
     prompt = "";
   }
   const int task_id = llama.request_completion(
       {{"prompt", prompt}, {"n_predict", 0}}, false, true);
   task_result result = llama.next_result(task_id);
-  std::string embeddingResp = result.result_json.dump();
+  std::vector<float> embedding_result = result.result_json["embedding"];
   auto resp = nitro_utils::nitroHttpResponse();
-  resp->setBody(embeddingResp);
+  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
+  resp->setBody(embedding_resp);
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
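This hunk makes the request side line up with OpenAI's embeddings API, which sends the text under "input" rather than the endpoint's previous "content" key. Note the handler passes 0 for prompt_tokens, so both usage counters in the response read 0. A hypothetical client-side sketch of the new request body, built with jsoncpp (the endpoint path is not shown in this diff):

#include <iostream>
#include <json/json.h>
#include <string>

int main() {
  // Body an OpenAI-style client now sends to the embedding endpoint.
  Json::Value body;
  body["input"] = "hello world"; // previously the endpoint expected "content"

  Json::StreamWriterBuilder writer;
  writer["indentation"] = ""; // compact, like the response payload
  std::cout << Json::writeString(writer, body) << "\n"; // {"input":"hello world"}
  return 0;
}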
@@ -363,7 +403,7 @@ void llamaCPP::loadModel(
   llama.initialize();
 
   Json::Value jsonResp;
-  jsonResp["message"] = "Failed to load model";
+  jsonResp["message"] = "Model loaded successfully";
   model_loaded = true;
   auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
 
0 commit comments