Commit 29a9f5f

Author: Mark Nelson

examples/server: Add support for chat memory.
Signed-off-by: Mark Nelson <[email protected]>
1 parent 4856f80 commit 29a9f5f

2 files changed (+47, -4)


examples/server/server.cpp (+37, -4)

@@ -1,6 +1,7 @@
 #include "utils.hpp"
 
 #include "arg.h"
+#include "chat-memory/chat_memory.h"
 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
@@ -3911,8 +3912,21 @@ int main(int argc, char ** argv) {
     auto completion_id = gen_chatcmplid();
     std::vector<server_task> tasks;
 
+    std::string conv_id = "";
     try {
-        const auto & prompt = data.at("prompt");
+        // Read conv_id from JSON or skip if empty.
+        conv_id = data.value("conv_id", "");
+        if (conv_id.empty()) {
+            SRV_INF("%s", "No conv_id provided, chat memory will be disabled.\n");
+        }
+
+        std::string prefix = "";
+        if (!conv_id.empty()) {
+            auto& mem = get_or_create_chat_memory(conv_id);
+            prefix = mem.format_injection_prompt() + "\n\n";
+        }
+        std::string prompt = prefix + data.at("prompt").get<std::string>();
+
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
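Note: chat-memory/chat_memory.h itself is not part of this commit, so the interface sketched below is reconstructed from the call sites in this diff (get_or_create_chat_memory, format_injection_prompt, parse_and_execute_command_json, process_response). The exact signatures, storage model, and comments are assumptions.

```cpp
// Hypothetical sketch of the interface implied by the call sites above;
// chat-memory/chat_memory.h is not shown in this commit.
#include <functional>
#include <string>

#include <nlohmann/json.hpp>
using json = nlohmann::json;

class ChatMemory {
public:
    // Text prepended to the user prompt ("prefix" above): a rendering of
    // the stored memory plus instructions for emitting memory commands.
    std::string format_injection_prompt() const;

    // Non-streaming path: scan a finished result for memory commands
    // emitted by the model and apply them to the store.
    void parse_and_execute_command_json(json & out);

    // Streaming path: inspect each chunk and forward bytes to the client
    // through the supplied writer callback.
    void process_response(json & res_json, bool is_stop,
                          const std::function<void(const char *, size_t)> & writer);
};

// One memory instance per conversation, keyed by conv_id.
ChatMemory & get_or_create_chat_memory(const std::string & conv_id);
```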
@@ -3953,12 +3967,24 @@
     ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
         if (results.size() == 1) {
             // single result
-            res_ok(res, results[0]->to_json());
+            json out = results[0]->to_json();
+            // Parse model output for memory commands
+            if (!conv_id.empty() && !results.empty()) {
+                auto& mem = get_or_create_chat_memory(conv_id);
+                mem.parse_and_execute_command_json(out);
+            }
+            res_ok(res, out);
         } else {
             // multiple results (multitask)
             json arr = json::array();
             for (auto & res : results) {
-                arr.push_back(res->to_json());
+                json out = res->to_json();
+                // Parse model output for memory commands from each task
+                if (!conv_id.empty() && !out.empty()) {
+                    auto& mem = get_or_create_chat_memory(conv_id);
+                    mem.parse_and_execute_command_json(out);
+                }
+                arr.push_back(out);
             }
             res_ok(res, arr);
         }
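Both branches follow the same pattern: copy the result to out, let the memory layer inspect it, then return out. The commit does not define the command format here, so the standalone illustration below invents one purely for demonstration; only the parse-then-dispatch shape matches the diff.

```cpp
// Standalone illustration (not the commit's code): detect a hypothetical
// memory command embedded in a completion result's "content" field.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // A result as it might come back from results[0]->to_json().
    json out = {
        {"content", R"({"memory_command":"remember","key":"name","value":"Mark"})"}
    };

    // Sketch of what parse_and_execute_command_json might do: try to read
    // the content as JSON and look for a command object.
    json cmd = json::parse(out.value("content", ""), nullptr, /*allow_exceptions=*/false);
    if (!cmd.is_discarded() && cmd.contains("memory_command")) {
        std::cout << "would store " << cmd["key"] << " = " << cmd["value"] << "\n";
    }
    return 0;
}
```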
@@ -3968,9 +3994,16 @@
 
         ctx_server.queue_results.remove_waiting_task_ids(task_ids);
     } else {
-        const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
+        const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat, conv_id](size_t, httplib::DataSink & sink) {
             ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
                 json res_json = result->to_json();
+                if (!conv_id.empty()) {
+                    auto & mem = get_or_create_chat_memory(conv_id);
+                    mem.process_response(res_json, result->is_stop(),
+                        [&sink](const char* data, size_t size) {
+                            sink.write(data, size);
+                        });
+                }
                 if (res_json.is_array()) {
                     for (const auto & res : res_json) {
                         if (!server_sent_event(sink, "data", res)) {
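The streaming branch hands the memory layer a writer callback so it can decide what reaches the client as chunks arrive. A minimal standalone sketch of that pass-through wiring follows; the forwarding policy inside process_response is assumed, not taken from the commit.

```cpp
// Standalone sketch of the writer-callback wiring used above. The real
// code passes a lambda that writes to httplib::DataSink; a std::string
// stands in for the sink here.
#include <functional>
#include <iostream>
#include <string>

// Stand-in for ChatMemory::process_response(): forward each chunk to the
// client; a real implementation could also buffer or strip memory commands.
static void process_response(const std::string & chunk, bool is_stop,
                             const std::function<void(const char *, size_t)> & writer) {
    if (!chunk.empty()) {
        writer(chunk.data(), chunk.size());
    }
    if (is_stop) {
        // Final chunk: a real implementation would flush any held-back state.
    }
}

int main() {
    std::string sink; // stand-in for httplib::DataSink
    auto writer = [&sink](const char * data, size_t size) { sink.append(data, size); };

    process_response("Hello, ", /*is_stop=*/false, writer);
    process_response("world!",  /*is_stop=*/true,  writer);

    std::cout << sink << "\n"; // prints: Hello, world!
    return 0;
}
```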

examples/server/utils.hpp (+10)

@@ -530,6 +530,11 @@ static json oaicompat_completion_params_parse(const json & body) {
         throw std::runtime_error("Only no echo is supported");
     }
 
+    // Added for server-side chat memory
+    if (body.contains("conv_id")) {
+        llama_params["conv_id"] = body["conv_id"];
+    }
+
     // Params supported by OAI but unsupported by llama.cpp
     static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
     for (const auto & param : unsupported_params) {
@@ -631,6 +636,11 @@ static json oaicompat_completion_params_parse(
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+
+    // Added for server-side chat memory
+    if (body.contains("conv_id")) {
+        llama_params["conv_id"] = body["conv_id"];
+    }
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }
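End to end, a client opts into chat memory by adding conv_id to an otherwise standard OpenAI-style request body. The snippet below mirrors the forwarding added above; the body shape around conv_id is illustrative.

```cpp
// Illustrative only: a chat-completions body carrying the new conv_id
// field, and the forwarding into llama_params added in this commit.
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    json body = {
        {"model",    "any"},
        {"conv_id",  "user-42-session-7"}, // enables server-side chat memory
        {"messages", json::array({
            {{"role", "user"}, {"content", "My name is Mark."}}
        })}
    };

    // Mirrors the forwarding added in utils.hpp above.
    json llama_params;
    if (body.contains("conv_id")) {
        llama_params["conv_id"] = body["conv_id"];
    }

    std::cout << llama_params.dump() << "\n"; // {"conv_id":"user-42-session-7"}
    return 0;
}
```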
