Skip to content

Commit 20d3be8

Browse files
authored
Merge pull request #149 from janhq/hotfix_embedding
Hotfix embedding
2 parents 0f66207 + 6aa879b commit 20d3be8

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

controllers/llamaCPP.cc

+40-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,40 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
2727

2828
// --------------------------------------------
2929

30+
// Serializes an embedding vector into an OpenAI-compatible
// "/v1/embeddings" JSON response body.
//
// @param embedding     Embedding values produced by the model.
// @param prompt_tokens Token count of the input prompt; reported as both
//                      prompt_tokens and total_tokens, since an embedding
//                      request produces no completion tokens.
// @param model         Model identifier echoed back in the "model" field.
//                      Defaults to the placeholder "_" for backward
//                      compatibility with existing callers.
// @return Compact (non-indented) JSON string.
std::string create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens,
                                     const std::string &model = "_") {
  Json::Value root;
  root["object"] = "list";

  // Copy the raw floats into a JSON array.
  Json::Value embeddingArray(Json::arrayValue);
  for (const float value : embedding) {
    embeddingArray.append(value);
  }

  // A single data item wrapping the embedding, per the OpenAI schema.
  Json::Value dataItem;
  dataItem["object"] = "embedding";
  dataItem["embedding"] = embeddingArray;
  dataItem["index"] = 0;

  Json::Value dataArray(Json::arrayValue);
  dataArray.append(dataItem);
  root["data"] = dataArray;

  root["model"] = model;

  Json::Value usage;
  usage["prompt_tokens"] = prompt_tokens;
  // No completion tokens are generated for embeddings, so total == prompt.
  usage["total_tokens"] = prompt_tokens;
  root["usage"] = usage;

  Json::StreamWriterBuilder writer;
  writer["indentation"] = ""; // Compact output (no pretty-printing).
  return Json::writeString(writer, root);
}
63+
3064
std::string create_full_return_json(const std::string &id,
3165
const std::string &model,
3266
const std::string &content,
@@ -245,17 +279,18 @@ void llamaCPP::embedding(
245279
const auto &jsonBody = req->getJsonObject();
246280

247281
json prompt;
248-
if (jsonBody->isMember("content") != 0) {
249-
prompt = (*jsonBody)["content"].asString();
282+
if (jsonBody->isMember("input") != 0) {
283+
prompt = (*jsonBody)["input"].asString();
250284
} else {
251285
prompt = "";
252286
}
253287
const int task_id = llama.request_completion(
254288
{{"prompt", prompt}, {"n_predict", 0}}, false, true);
255289
task_result result = llama.next_result(task_id);
256-
std::string embeddingResp = result.result_json.dump();
290+
std::vector<float> embedding_result = result.result_json["embedding"];
257291
auto resp = nitro_utils::nitroHttpResponse();
258-
resp->setBody(embeddingResp);
292+
std::string embedding_resp = create_embedding_payload(embedding_result, 0);
293+
resp->setBody(embedding_resp);
259294
resp->setContentTypeString("application/json");
260295
callback(resp);
261296
return;
@@ -363,7 +398,7 @@ void llamaCPP::loadModel(
363398
llama.initialize();
364399

365400
Json::Value jsonResp;
366-
jsonResp["message"] = "Failed to load model";
401+
jsonResp["message"] = "Model loaded successfully";
367402
model_loaded = true;
368403
auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
369404

0 commit comments

Comments
 (0)