
Commit cab9407

Merge pull request #1873 from janhq/chore/sync-main-to-dev

chore: sync main to dev

2 parents a35a05f + df60e6b

11 files changed: +420 −109 lines

.github/workflows/template-build-linux.yml
Lines changed: 1 addition & 1 deletion

@@ -243,7 +243,7 @@ jobs:
       with:
         upload_url: ${{ inputs.upload_url }}
         asset_path: ./engine/cortex.tar.gz
-        asset_name: cortex-${{ inputs.new_version }}-linux${{ inputs.arch }}.tar.gz
+        asset_name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}.tar.gz
         asset_content_type: application/zip

     - name: Upload release assert if public provider is github

engine/controllers/engines.cc
Lines changed: 2 additions & 2 deletions

@@ -251,9 +251,9 @@ void Engines::InstallRemoteEngine(
                      .get("get_models_url", "")
                      .asString();

-  if (engine.empty() || type.empty() || url.empty()) {
+  if (engine.empty() || type.empty()) {
     Json::Value res;
-    res["message"] = "Engine name, type, url are required";
+    res["message"] = "Engine name, type are required";
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
     resp->setStatusCode(k400BadRequest);
     callback(resp);
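
The validation above drops url from the required fields, so a remote engine can now be installed with just a name and a type. For context, a minimal sketch of the same reject-with-400 pattern using stock Drogon in place of the project's cortex_utils::CreateCortexHttpJsonResponse helper (a standalone illustration, not the actual controller code):

#include <drogon/HttpResponse.h>
#include <json/json.h>
#include <string>

// Reject-with-400 pattern from the handler above, using stock Drogon.
// Returns nullptr when the input is valid. (Sketch only.)
drogon::HttpResponsePtr ValidateInstallRequest(const std::string& engine,
                                               const std::string& type) {
  if (engine.empty() || type.empty()) {
    Json::Value res;
    res["message"] = "Engine name, type are required";
    auto resp = drogon::HttpResponse::newHttpJsonResponse(res);
    resp->setStatusCode(drogon::k400BadRequest);
    return resp;  // the controller would hand this to its callback
  }
  return nullptr;
}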

engine/controllers/models.cc
Lines changed: 1 addition & 0 deletions

@@ -526,6 +526,7 @@ void Models::StartModel(
   if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
     auto model_path = o.asString();
     if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
+      mp = model_path;
       // Bypass if model does not exist in DB and llama_model_path exists
       if (std::filesystem::exists(model_path) &&
           !model_service_->HasModel(model_handle)) {
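
The single added line backfills model_path from llama_model_path when a request supplies only the latter, so downstream code can rely on one key. A self-contained sketch of the same jsoncpp fallback pattern (illustration only, with a hypothetical path; not the controller itself):

#include <json/json.h>
#include <iostream>

// When a request carries only "llama_model_path", copy it into
// "model_path" so the rest of the handler can rely on a single key.
int main() {
  Json::Value req;
  req["llama_model_path"] = "/models/llama-3.2-1b.gguf";  // hypothetical path

  if (auto& o = req["llama_model_path"]; !o.isNull()) {
    if (auto& mp = req["model_path"]; mp.isNull()) {
      // operator[] inserted a null member, so assigning through the
      // reference fills it in place, exactly like `mp = model_path;` above.
      mp = o.asString();
    }
  }
  std::cout << req.toStyledString();  // now contains both keys
}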

engine/extensions/remote-engine/remote_engine.cc
Lines changed: 23 additions & 5 deletions

@@ -27,14 +27,16 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
   auto* context = static_cast<StreamContext*>(userdata);
   std::string chunk(ptr, size * nmemb);
   CTL_DBG(chunk);
-  auto check_error = json_helper::ParseJsonString(chunk);
-  if (check_error.isMember("error")) {
+  Json::Value check_error;
+  Json::Reader reader;
+  if (reader.parse(chunk, check_error)) {
     CTL_WRN(chunk);
     Json::Value status;
     status["is_done"] = true;
     status["has_error"] = true;
     status["is_stream"] = true;
     status["status_code"] = k400BadRequest;
+    context->need_stop = false;
     (*context->callback)(std::move(status), std::move(check_error));
     return size * nmemb;
   }

@@ -58,7 +60,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
       status["is_done"] = true;
       status["has_error"] = false;
       status["is_stream"] = true;
-      status["status_code"] = 200;
+      status["status_code"] = k200OK;
+      context->need_stop = false;
       (*context->callback)(std::move(status), Json::Value());
       break;
     }

@@ -169,6 +172,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

   curl_slist_free_all(headers);
   curl_easy_cleanup(curl);
+  if (context.need_stop) {
+    CTL_DBG("No stop message received, need to stop");
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = false;
+    status["is_stream"] = true;
+    status["status_code"] = k200OK;
+    (*context.callback)(std::move(status), Json::Value());
+  }
   return response;
 }

@@ -602,6 +614,7 @@ void RemoteEngine::HandleChatCompletion(
     status["status_code"] = k500InternalServerError;
     Json::Value error;
     error["error"] = "Failed to parse response";
+    LOG_WARN << "Failed to parse response: " << response.body;
     callback(std::move(status), std::move(error));
     return;
   }

@@ -626,15 +639,19 @@ void RemoteEngine::HandleChatCompletion(

       try {
         response_json["stream"] = false;
+        if (!response_json.isMember("model")) {
+          response_json["model"] = model;
+        }
         response_str = renderer_.Render(template_str, response_json);
       } catch (const std::exception& e) {
         throw std::runtime_error("Template rendering error: " +
                                  std::string(e.what()));
       }
     } catch (const std::exception& e) {
       // Log error and potentially rethrow or handle accordingly
-      LOG_WARN << "Error in TransformRequest: " << e.what();
-      LOG_WARN << "Using original request body";
+      LOG_WARN << "Error: " << e.what();
+      LOG_WARN << "Response: " << response.body;
+      LOG_WARN << "Using original body";
       response_str = response_json.toStyledString();
     }

@@ -649,6 +666,7 @@ void RemoteEngine::HandleChatCompletion(
     Json::Value error;
     error["error"] = "Failed to parse response";
     callback(std::move(status), std::move(error));
+    LOG_WARN << "Failed to parse response: " << response_str;
     return;
   }

engine/extensions/remote-engine/remote_engine.h
Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ struct StreamContext {
   std::string model;
   extensions::TemplateRenderer& renderer;
   std::string stream_template;
+  bool need_stop = true;
 };
 struct CurlResponse {
   std::string body;
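
Taken together, the .cc and .h changes make the stream terminate exactly once: need_stop starts out true, is cleared whenever the write callback emits a terminal status itself, and the post-transfer check in MakeStreamingChatCompletionRequest synthesizes a final done message for providers that close the connection without sending one. A minimal sketch of the underlying libcurl write-callback contract (DemoContext and the "[DONE]" marker are illustrative stand-ins, not the engine's actual types or parsing logic):

#include <curl/curl.h>
#include <string>

struct DemoContext {
  std::string transcript;
  bool need_stop = true;  // cleared once a terminal chunk is observed
};

// libcurl invokes this once per received buffer; returning anything other
// than size * nmemb tells curl to abort the transfer.
static size_t WriteCallback(char* ptr, size_t size, size_t nmemb,
                            void* userdata) {
  auto* ctx = static_cast<DemoContext*>(userdata);
  std::string chunk(ptr, size * nmemb);
  ctx->transcript += chunk;
  if (chunk.find("[DONE]") != std::string::npos) {
    ctx->need_stop = false;  // provider sent an explicit end-of-stream marker
  }
  return size * nmemb;
}

int main() {
  DemoContext ctx;
  CURL* curl = curl_easy_init();
  if (!curl) return 1;
  curl_easy_setopt(curl, CURLOPT_URL, "https://example.com/v1/chat/completions");
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ctx);
  curl_easy_perform(curl);
  curl_easy_cleanup(curl);
  if (ctx.need_stop) {
    // Mirror of the patch: synthesize the final "done" status for providers
    // that close the connection without sending one.
  }
}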

engine/extensions/template_renderer.cc
Lines changed: 7 additions & 6 deletions

@@ -7,7 +7,9 @@
 #include <regex>
 #include <stdexcept>
 #include "utils/logging_utils.h"
+#include "utils/string_utils.h"
 namespace extensions {
+
 TemplateRenderer::TemplateRenderer() {
   // Configure Inja environment
   env_.set_trim_blocks(true);

@@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
     const auto& value = *args[0];

     if (value.is_string()) {
-      return nlohmann::json(std::string("\"") + value.get<std::string>() +
+      return nlohmann::json(std::string("\"") +
+                            string_utils::EscapeJson(value.get<std::string>()) +
                             "\"");
     }
     return value;

@@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
     std::string result = env_.render(tmpl, template_data);

     // Clean up any potential double quotes in JSON strings
-    result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
+    // result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");

     LOG_DEBUG << "Result: " << result;

-    // Validate JSON
-    auto parsed = nlohmann::json::parse(result);
-
     return result;
   } catch (const std::exception& e) {
     LOG_ERROR << "Template rendering failed: " << e.what();
+    LOG_ERROR << "Data: " << data.toStyledString();
     LOG_ERROR << "Template: " << tmpl;
     throw std::runtime_error(std::string("Template rendering failed: ") +
                              e.what());

@@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
                              e.what());
   }
 }
-}  // namespace remote_engine
+}  // namespace extensions
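
This change replaces the fragile post-hoc regex cleanup (and the strict nlohmann::json::parse validation that threw on malformed output) with proper escaping at the point where string values are injected into the template. A hedged sketch of what a JSON string escaper such as string_utils::EscapeJson plausibly does; the project's real implementation may differ in detail:

#include <cstdio>
#include <string>

// Escape the characters that would otherwise terminate or corrupt a JSON
// string literal. (Sketch of a typical escaper, not the project's code.)
std::string EscapeJson(const std::string& s) {
  std::string out;
  out.reserve(s.size());
  for (unsigned char c : s) {
    switch (c) {
      case '"':  out += "\\\""; break;  // would otherwise end the string early
      case '\\': out += "\\\\"; break;
      case '\b': out += "\\b"; break;
      case '\f': out += "\\f"; break;
      case '\n': out += "\\n"; break;
      case '\r': out += "\\r"; break;
      case '\t': out += "\\t"; break;
      default:
        if (c < 0x20) {  // remaining control characters -> \u00XX
          char buf[8];
          std::snprintf(buf, sizeof(buf), "\\u%04x", c);
          out += buf;
        } else {
          out += static_cast<char>(c);
        }
    }
  }
  return out;
}

int main() {
  // Unescaped, the inner quotes would corrupt the surrounding JSON document.
  std::printf("{\"content\": \"%s\"}\n", EscapeJson("she said \"hi\"\n").c_str());
}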

engine/services/inference_service.cc
Lines changed: 49 additions & 33 deletions

@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
   }
   function_calling_utils::PreprocessRequest(json_body);
   auto tool_choice = json_body->get("tool_choice", Json::Value::null);
+  auto model_id = json_body->get("model", "").asString();
+  if (saved_models_.find(model_id) != saved_models_.end()) {
+    // check if model is started, if not start it first
+    Json::Value root;
+    root["model"] = model_id;
+    root["engine"] = engine_type;
+    auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    if (status != drogon::k200OK) {
+      CTL_INF("Model is not loaded, start loading it: " << model_id);
+      auto res = LoadModel(saved_models_.at(model_id));
+      // ignore return result
+    }
+  }
+
   auto engine_result = engine_service_->GetLoadedEngine(engine_type);
   if (engine_result.has_error()) {
     Json::Value res;

@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     LOG_WARN << "Engine is not loaded yet";
     return cpp::fail(std::make_pair(stt, res));
   }
+
+  if (!model_id.empty()) {
+    if (auto model_service = model_service_.lock()) {
+      auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
+      if (metadata_ptr != nullptr &&
+          !metadata_ptr->tokenizer->chat_template.empty()) {
+        auto tokenizer = metadata_ptr->tokenizer;
+        auto messages = (*json_body)["messages"];
+        Json::Value messages_jsoncpp(Json::arrayValue);
+        for (auto message : messages) {
+          messages_jsoncpp.append(message);
+        }

-  {
-    auto model_id = json_body->get("model", "").asString();
-    if (!model_id.empty()) {
-      if (auto model_service = model_service_.lock()) {
-        auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
-        if (metadata_ptr != nullptr &&
-            !metadata_ptr->tokenizer->chat_template.empty()) {
-          auto tokenizer = metadata_ptr->tokenizer;
-          auto messages = (*json_body)["messages"];
-          Json::Value messages_jsoncpp(Json::arrayValue);
-          for (auto message : messages) {
-            messages_jsoncpp.append(message);
-          }
-
-          Json::Value tools(Json::arrayValue);
-          Json::Value template_data_json;
-          template_data_json["messages"] = messages_jsoncpp;
-          // template_data_json["tools"] = tools;
-
-          auto prompt_result = jinja::RenderTemplate(
-              tokenizer->chat_template, template_data_json,
-              tokenizer->bos_token, tokenizer->eos_token,
-              tokenizer->add_bos_token, tokenizer->add_eos_token,
-              tokenizer->add_generation_prompt);
-          if (prompt_result.has_value()) {
-            (*json_body)["prompt"] = prompt_result.value();
-            Json::Value stops(Json::arrayValue);
-            stops.append(tokenizer->eos_token);
-            (*json_body)["stop"] = stops;
-          } else {
-            CTL_ERR("Failed to render prompt: " + prompt_result.error());
-          }
+        Json::Value tools(Json::arrayValue);
+        Json::Value template_data_json;
+        template_data_json["messages"] = messages_jsoncpp;
+        // template_data_json["tools"] = tools;
+
+        auto prompt_result = jinja::RenderTemplate(
+            tokenizer->chat_template, template_data_json, tokenizer->bos_token,
+            tokenizer->eos_token, tokenizer->add_bos_token,
+            tokenizer->add_eos_token, tokenizer->add_generation_prompt);
+        if (prompt_result.has_value()) {
+          (*json_body)["prompt"] = prompt_result.value();
+          Json::Value stops(Json::arrayValue);
+          stops.append(tokenizer->eos_token);
+          (*json_body)["stop"] = stops;
+        } else {
+          CTL_ERR("Failed to render prompt: " + prompt_result.error());
        }
      }
    }
  }

-  CTL_INF("Json body inference: " + json_body->toStyledString());
+
+  CTL_DBG("Json body inference: " + json_body->toStyledString());

   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {

@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
     std::get<RemoteEngineI*>(engine_result.value())
         ->LoadModel(json_body, std::move(cb));
   }
+  if (!engine_service_->IsRemoteEngine(engine_type)) {
+    auto model_id = json_body->get("model", "").asString();
+    saved_models_[model_id] = json_body;
+  }
   return std::make_pair(stt, r);
 }

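The service now remembers the original LoadModel request body for every local (non-remote) model and replays it from HandleChatCompletion when the model turns out not to be loaded, giving chat requests a lazy auto-restart path. A condensed standalone sketch of that cache-and-replay pattern (Service and IsLoaded are hypothetical; the real logic lives in InferenceService with GetModelStatus and LoadModel):

#include <json/json.h>
#include <memory>
#include <string>
#include <unordered_map>

class Service {
 public:
  void LoadModel(std::shared_ptr<Json::Value> body) {
    // ... ask the engine to load the model ...
    // Remember the full request body so it can be replayed later.
    saved_models_[body->get("model", "").asString()] = body;
  }

  void HandleChatCompletion(const std::string& model_id) {
    if (auto it = saved_models_.find(model_id);
        it != saved_models_.end() && !IsLoaded(model_id)) {
      LoadModel(it->second);  // replay the original load request; result ignored
    }
    // ... forward the completion request to the engine ...
  }

 private:
  bool IsLoaded(const std::string& /*model_id*/) const {
    return false;  // stand-in for a real engine status query
  }
  std::unordered_map<std::string, std::shared_ptr<Json::Value>> saved_models_;
};
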
engine/services/inference_service.h
Lines changed: 3 additions & 1 deletion

@@ -47,7 +47,7 @@ class InferenceService {

   cpp::result<void, InferResult> HandleRouteRequest(
       std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body);
-
+
   InferResult LoadModel(std::shared_ptr<Json::Value> json_body);

   InferResult UnloadModel(const std::string& engine,

@@ -74,4 +74,6 @@ class InferenceService {
  private:
   std::shared_ptr<EngineService> engine_service_;
   std::weak_ptr<ModelService> model_service_;
+  using SavedModel = std::shared_ptr<Json::Value>;
+  std::unordered_map<std::string, SavedModel> saved_models_;
 };
