Skip to content

Commit 5f2ce78

Browse files
author
Mark Nelson
committed
Chat memory interface and simple implementation.
Signed-off-by: Mark Nelson <[email protected]>
1 parent e39e727 commit 5f2ce78

File tree

6 files changed

+1527
-0
lines changed

6 files changed

+1527
-0
lines changed

common/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ add_library(${TARGET} STATIC
5858
base64.hpp
5959
chat.cpp
6060
chat.h
61+
chat-memory/chat_memory.cpp
62+
chat-memory/chat_memory.h
63+
chat-memory/chat_memory_simple.cpp
64+
chat-memory/chat_memory_simple.h
65+
chat-memory/chat_memory_factory.cpp
6166
common.cpp
6267
common.h
6368
console.cpp

common/chat-memory/chat_memory.cpp

+323
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
// chat_memory.cpp
#include "chat_memory.h"

#include <algorithm>  // std::min
#include <cstdlib>    // std::getenv
#include <ctime>
#include <iomanip>    // std::put_time
#include <iostream>
#include <regex>
#include <sstream>    // std::ostringstream
7+
// Route a model response through the memory subsystem.
//
// Streaming responses (SSE chunks) are accumulated chunk-by-chunk; when the
// final chunk arrives, any embedded memory command is executed and its output
// is emitted as one extra chunk before the stream is terminated with
// "data: [DONE]". Non-streaming responses are processed (and possibly
// rewritten) in place.
//
// @param response        parsed JSON of the current response/chunk
// @param is_final        true when this is the last chunk of a stream
// @param write_callback  sink for raw SSE bytes (streaming path only)
void ChatMemoryCommon::process_response(json& response, bool is_final, const WriteCallback& write_callback) {
    if (is_streaming_response(response)) {
        // Accumulate this chunk's delta content.
        process_streaming_chunk(response);

        if (is_final) {
            // Extract a memory command from the accumulated content, if any.
            std::regex json_pattern(R"(\{[^{}]*"memory_command"[^{}]*\})");
            std::smatch match;

            if (std::regex_search(accumulated_content, match, json_pattern)) {
                // Execute the command; a non-empty result is sent to the
                // client as an additional chat.completion.chunk.
                std::string memory_response = parse_and_execute_command(match.str());

                if (!memory_response.empty()) {
                    nlohmann::ordered_json memory_chunk = {
                        {"id", "memory_response"},
                        {"object", "chat.completion.chunk"},
                        {"created", (int)time(NULL)},
                        {"model", "memory_system"},
                        {"choices", {{
                            {"index", 0},
                            {"delta", {{"content", "\n\n" + memory_response}}},
                            {"finish_reason", nullptr}
                        }}}
                    };

                    std::string chunk_str = "data: " + memory_chunk.dump() + "\n\n";
                    write_callback(chunk_str.c_str(), chunk_str.size());
                }
            }

            // Terminate the SSE stream exactly once. The original emitted
            // this identical message from both the matched and unmatched
            // branches; the duplication is consolidated here.
            const std::string done_msg = "data: [DONE]\n\n";
            write_callback(done_msg.c_str(), done_msg.size());

            // Clear per-stream state for the next streaming response.
            reset_streaming();
        }
    } else {
        // Non-streaming: scan the body and append memory output in place.
        process_regular_response(response);
    }
}
61+
62+
// Logging functions implementations
63+
bool ChatMemoryCommon::is_debug_enabled() {
64+
static bool checked = false;
65+
static bool enabled = false;
66+
67+
if (!checked) {
68+
checked = true;
69+
// Check environment variable first
70+
const char* debug_env = std::getenv("LLAMA_MEMORY_DEBUG");
71+
if (debug_env && (std::string(debug_env) == "1" || std::string(debug_env) == "true")) {
72+
enabled = true;
73+
} else {
74+
// Check compile-time flag
75+
enabled = CHAT_MEMORY_DEBUG != 0;
76+
}
77+
}
78+
return enabled;
79+
}
80+
81+
void ChatMemoryCommon::log_debug(const std::string& message) const {
82+
if (!is_debug_enabled()) return;
83+
84+
// Get current time for timestamp
85+
auto now = std::time(nullptr);
86+
auto tm = *std::localtime(&now);
87+
std::ostringstream timestamp;
88+
timestamp << std::put_time(&tm, "%Y-%m-%d %H:%M:%S");
89+
90+
std::cerr << "[" << timestamp.str() << "] [ChatMemory Debug] " << message << std::endl;
91+
}
92+
93+
// Debug-traces an executed memory command together with its JSON response.
void ChatMemoryCommon::log_command(const std::string& command, const nlohmann::ordered_json& response) const {
    if (!is_debug_enabled()) {
        return;
    }
    log_debug("Command executed: " + command);
    log_debug("Response: " + response.dump(2));
}
99+
100+
bool ChatMemoryCommon::is_streaming_response(const json& j) const {
101+
return j.contains("object") && j["object"].get<std::string>() == "chat.completion.chunk";
102+
}
103+
104+
void ChatMemoryCommon::track_response(const std::string& response) {
105+
ChatMemoryCommon::log_debug("track_response: Adding response with size " + std::to_string(response.size()) + " bytes");
106+
107+
recent_responses.push_back(response);
108+
if (recent_responses.size() > max_context_responses) {
109+
ChatMemoryCommon::log_debug("track_response: Removing oldest response (exceeded max_context_responses)");
110+
recent_responses.pop_front();
111+
}
112+
}
113+
114+
// Check if a valid memory command JSON is being used
115+
bool ChatMemoryCommon::is_valid_memory_json(const std::string& output) const {
116+
ChatMemoryCommon::log_debug("is_valid_memory_json: Checking if \"" + output.substr(0, std::min(output.size(), size_t(50))) +
117+
(output.size() > 50 ? "..." : "") + "\" contains valid memory command JSON");
118+
119+
// Look for valid memory_command JSON pattern
120+
std::regex memory_cmd_pattern(R"(\{"memory_command":[^}]+\})");
121+
bool valid = std::regex_search(output, memory_cmd_pattern);
122+
123+
ChatMemoryCommon::log_debug("is_valid_memory_json: Result = " + std::string(valid ? "valid" : "invalid") + " memory command JSON");
124+
return valid;
125+
}
126+
127+
// Main entry point for processing model output and executing commands
128+
// Main entry point for processing model output and executing memory commands.
//
// Scans `output` for JSON blocks containing a "memory_command" key, parses
// the first one that executes successfully, and returns the human-readable
// result produced by execute_json_command. Returns "" when no command is
// found or none executes to a non-empty response.
//
// @param output  raw model text (may contain prose around the JSON)
// @return        human-readable command result, or "" if none
std::string ChatMemoryCommon::parse_and_execute_command(const std::string& output) {
    log_debug("parse_and_execute_command: Processing output for memory commands");

    // Cheap substring pre-filter before any regex work.
    if (output.find("memory_command") == std::string::npos || output.find('{') == std::string::npos) {
        log_debug("parse_and_execute_command: No memory commands found");
        return ""; // No memory commands found
    }

    // Check if this appears to be a valid JSON command structure
    if (!is_valid_memory_json(output)) {
        log_debug("parse_and_execute_command: Warning - Detected memory-related text without proper JSON format");
        std::cerr << "[ChatMemory] Warning: Detected memory-related text without proper JSON format.\n";
        // Continue anyway as regex might not catch all valid formats
    }

    // Matches a braced block with at most one level of nested braces.
    std::regex json_block(R"(\{[^{}]*(\{[^{}]*\}[^{}]*)*\})");
    auto begin = std::sregex_iterator(output.begin(), output.end(), json_block);
    auto end = std::sregex_iterator();

    if (begin == end) {
        log_debug("parse_and_execute_command: No JSON blocks found");
        std::cerr << "[ChatMemory] No JSON blocks found in output.\n";
        return "";
    }

    // Try each candidate block; the first one that parses and executes to a
    // non-empty response wins. Parse failures are logged and skipped.
    for (auto it = begin; it != end; ++it) {
        const std::string json_text = it->str();
        if (json_text.find("memory_command") == std::string::npos) {
            continue;
        }

        ChatMemoryCommon::log_debug("parse_and_execute_command: Found potential memory command JSON: " +
                  json_text.substr(0, std::min(json_text.size(), size_t(100))) +
                  (json_text.size() > 100 ? "..." : ""));

        try {
            json j = json::parse(json_text);

            // Execute the command and get the human-readable response
            std::string human_response = execute_json_command(j);
            if (!human_response.empty()) {
                // Track the response for context management
                track_response(human_response);

                log_debug("parse_and_execute_command: Successfully executed command, returning response");
                return human_response;
            }
        } catch (const std::exception& e) {
            log_debug("parse_and_execute_command: JSON parse error: " + std::string(e.what()));
            std::cerr << "[ChatMemory] JSON parse error: " << e.what() << "\n";
            std::cerr << "[ChatMemory] Offending input: " << json_text << "\n";
        }
    }

    log_debug("parse_and_execute_command: No valid memory commands found");
    return ""; // No valid commands found
}
185+
186+
// Scans a non-streaming JSON response for memory commands and appends any
// command output (separated by a newline) to the same field the model text
// was read from. Recognizes "content" (chat completions) and "text"
// (plain completions) fields; other shapes are left untouched.
void ChatMemoryCommon::parse_and_execute_command_json(json& j) {
    log_debug("parse_and_execute_command_json: Processing JSON response");

    std::string model_output;

    if (j.contains("content")) {
        model_output = j["content"].get<std::string>();
        log_debug("parse_and_execute_command_json: Found content field");
    } else if (j.contains("text")) {
        model_output = j["text"].get<std::string>();
        log_debug("parse_and_execute_command_json: Found text field");
    } else {
        log_debug("parse_and_execute_command_json: No recognizable output format");
        return;
    }

    const std::string memory_response = parse_and_execute_command(model_output);
    if (memory_response.empty()) {
        log_debug("parse_and_execute_command_json: No memory response to append");
        return;
    }

    log_debug("parse_and_execute_command_json: Found memory response, appending to output");

    const std::string combined = model_output + "\n" + memory_response;
    if (j.contains("content")) {
        j["content"] = combined;
    } else if (j.contains("text")) {
        j["text"] = combined;
    }
}
221+
222+
// Accumulates the delta content of one streaming chunk into
// accumulated_content; commands are extracted later, on the final chunk.
// Malformed chunks are logged and ignored.
void ChatMemoryCommon::process_streaming_chunk(json& j) {
    try {
        // Chunks may arrive either as a bare object ({"choices": [...]}) —
        // the shape is_streaming_response() matches — or wrapped in a
        // one-element array. The original code only handled the array form,
        // silently dropping object-form chunks.
        const json& root = (j.is_array() && !j.empty()) ? j[0] : j;

        // Use contains()/at() on a const reference instead of operator[]:
        // on a non-const json, operator[] inserts null for missing keys and
        // would mutate the chunk while merely inspecting it.
        if (root.contains("choices")) {
            const auto& choices = root.at("choices");
            if (choices.is_array() && !choices.empty()) {
                const auto& first = choices.at(0);
                if (first.contains("delta") && first.at("delta").contains("content")) {
                    std::string content = first.at("delta").at("content").get<std::string>();

                    // Just accumulate the content without modifying it.
                    accumulated_content += content;
                    log_debug("Chunk appended: '" + content + "'");
                    return;
                }
            }
        }
        log_debug("Chunk missing 'content' field: " + j.dump());
    } catch (const std::exception &e) {
        log_debug(std::string("Exception parsing chunk: ") + e.what());
    }
}
243+
244+
// Processes a complete (non-streaming) response: extracts the model text from
// whichever supported field is present (OpenAI chat format, bare "content",
// or bare "text"), executes any embedded memory command, and appends the
// command output to that same field.
void ChatMemoryCommon::process_regular_response(json& j) {
    log_debug("process_regular_response: Processing standard response format");

    // Truncate long strings to 100 chars for debug logging; this factors out
    // the substr/ellipsis expression that was repeated at each log site.
    const auto preview = [](const std::string& s) {
        return s.substr(0, std::min(s.size(), size_t(100))) + (s.size() > 100 ? "..." : "");
    };

    std::string model_output;
    bool found_content = false;

    // Handle different response formats.
    if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) {
        auto& first_choice = j["choices"][0];

        if (first_choice.contains("message") && first_choice["message"].contains("content")) {
            model_output = first_choice["message"]["content"].get<std::string>();
            found_content = true;
            log_debug("process_regular_response: Found content in OpenAI format: \"" + preview(model_output) + "\"");
        } else {
            log_debug("process_regular_response: No content found in OpenAI format");
        }
    } else if (j.contains("content")) {
        model_output = j["content"].get<std::string>();
        found_content = true;
        log_debug("process_regular_response: Found content field: \"" + preview(model_output) + "\"");
    } else if (j.contains("text")) {
        model_output = j["text"].get<std::string>();
        found_content = true;
        log_debug("process_regular_response: Found text field: \"" + preview(model_output) + "\"");
    } else {
        // Hoisted: the original serialized j.dump() three times here.
        const std::string dumped = j.dump();
        log_debug("process_regular_response: No recognizable output format. JSON structure: " +
                  dumped.substr(0, std::min(dumped.size(), size_t(500))) +
                  (dumped.size() > 500 ? "..." : ""));
        return;
    }

    if (!found_content || model_output.empty()) {
        log_debug("process_regular_response: No model output found to process");
        return;
    }

    // Execute any embedded memory command and append its output in place.
    std::string memory_response = parse_and_execute_command(model_output);
    if (memory_response.empty()) {
        log_debug("process_regular_response: No memory response to append");
        return;
    }

    log_debug("process_regular_response: Found memory response, appending to output");

    // Update the same field the model text was read from.
    if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) {
        auto& first_choice = j["choices"][0];
        if (first_choice.contains("message") && first_choice["message"].contains("content")) {
            first_choice["message"]["content"] = model_output + "\n" + memory_response;
            log_debug("process_regular_response: Updated content in OpenAI format");
        } else {
            log_debug("process_regular_response: Couldn't update content in OpenAI format");
        }
    } else if (j.contains("content")) {
        j["content"] = model_output + "\n" + memory_response;
        log_debug("process_regular_response: Updated content field");
    } else if (j.contains("text")) {
        j["text"] = model_output + "\n" + memory_response;
        log_debug("process_regular_response: Updated text field");
    } else {
        log_debug("process_regular_response: Couldn't find field to update with memory response");
    }
}
314+
315+
// Clears per-stream state so the next streaming response starts fresh.
void ChatMemoryCommon::reset_streaming() {
    log_debug("reset_streaming: Resetting streaming state");
    accumulated_content.clear();
    in_streaming_mode = false;
}
320+
321+
std::string ChatMemoryCommon::execute_json_command(nlohmann::ordered_json &j) {
322+
return "ChatMemoryCommon";
323+
}

0 commit comments

Comments
 (0)