Skip to content

Commit 9765576

Browse files
author
Mark Nelson
committed
Chat memory interface and simple implementation.
Signed-off-by: Mark Nelson <[email protected]>
1 parent e39e727 commit 9765576

File tree

6 files changed

+1548
-0
lines changed

6 files changed

+1548
-0
lines changed

common/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ add_library(${TARGET} STATIC
5858
base64.hpp
5959
chat.cpp
6060
chat.h
61+
chat-memory/chat_memory.cpp
62+
chat-memory/chat_memory.h
63+
chat-memory/chat_memory_simple.cpp
64+
chat-memory/chat_memory_simple.h
65+
chat-memory/chat_memory_factory.cpp
6166
common.cpp
6267
common.h
6368
console.cpp

common/chat-memory/chat_memory.cpp

+344
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
// chat_memory.cpp
#include "chat_memory.h"

#include <algorithm>  // std::min
#include <cstdlib>    // std::getenv
#include <ctime>
#include <iomanip>    // std::put_time
#include <iostream>
#include <regex>
#include <sstream>    // std::ostringstream
#include <string>
6+
7+
// Route a model response through memory-command handling.
//
// Streaming responses are accumulated chunk by chunk; on the final chunk any
// {"memory_command": ...} JSON found in the accumulated content is executed and
// its human-readable result is emitted as an extra SSE chunk ahead of the
// stream terminator. Non-streaming responses are processed in place.
//
// @param response       parsed JSON response (single chunk or full completion)
// @param is_final       true when this is the last chunk of a stream
// @param write_callback sink used to emit additional SSE data to the client
void ChatMemoryCommon::process_response(json& response, bool is_final, const WriteCallback& write_callback) {
    // For streaming responses
    if (is_streaming_response(response)) {
        // Process the chunk normally
        process_streaming_chunk(response);

        // On final chunk, check if we need to execute memory commands
        if (is_final) {
            // Extract memory commands from the accumulated content.
            // NOTE: this pattern only matches a flat (non-nested) JSON object.
            std::regex json_pattern(R"(\{[^{}]*"memory_command"[^{}]*\})");
            std::smatch match;

            if (std::regex_search(accumulated_content, match, json_pattern)) {
                // Execute the memory command
                std::string memory_response = parse_and_execute_command(match.str());

                if (!memory_response.empty()) {
                    // Create a JSON response with the memory results
                    nlohmann::ordered_json memory_chunk = {
                        {"id", "memory_response"},
                        {"object", "chat.completion.chunk"},
                        {"created", (int)time(NULL)},
                        {"model", "memory_system"},
                        {"choices", {{
                            {"index", 0},
                            {"delta", {{"content", "\n\n" + memory_response}}},
                            {"finish_reason", nullptr}
                        }}}
                    };

                    // Format and send the memory chunk before terminating
                    std::string chunk_str = "data: " + memory_chunk.dump() + "\n\n";
                    write_callback(chunk_str.c_str(), chunk_str.size());
                }
            }

            // Signal the end of the stream exactly once (this was previously
            // duplicated verbatim in both the match and no-match branches).
            const std::string done_msg = "data: [DONE]\n\n";
            write_callback(done_msg.c_str(), done_msg.size());

            // Reset streaming state
            reset_streaming();
        }
    } else {
        // For non-streaming responses, process directly
        process_regular_response(response);
    }
}
61+
62+
// Logging functions implementations
63+
bool ChatMemoryCommon::is_debug_enabled() {
64+
static bool checked = false;
65+
static bool enabled = false;
66+
67+
if (!checked) {
68+
checked = true;
69+
// Check environment variable first
70+
const char* debug_env = std::getenv("LLAMA_MEMORY_DEBUG");
71+
if (debug_env && (std::string(debug_env) == "1" || std::string(debug_env) == "true")) {
72+
enabled = true;
73+
} else {
74+
// Check compile-time flag
75+
enabled = CHAT_MEMORY_DEBUG != 0;
76+
}
77+
}
78+
return enabled;
79+
}
80+
81+
void ChatMemoryCommon::log_debug(const std::string& message) const {
82+
if (!is_debug_enabled()) return;
83+
84+
// Get current time for timestamp
85+
auto now = std::time(nullptr);
86+
auto tm = *std::localtime(&now);
87+
std::ostringstream timestamp;
88+
timestamp << std::put_time(&tm, "%Y-%m-%d %H:%M:%S");
89+
90+
std::cerr << "[" << timestamp.str() << "] [ChatMemory Debug] " << message << std::endl;
91+
}
92+
93+
// Log an executed memory command together with its JSON response (debug only).
void ChatMemoryCommon::log_command(const std::string& command, const nlohmann::ordered_json& response) const {
    if (!is_debug_enabled()) {
        return;
    }
    log_debug("Command executed: " + command);
    log_debug("Response: " + response.dump(2));
}
99+
100+
bool ChatMemoryCommon::is_streaming_response(const json& j) const {
101+
// Check if it's a direct object with the right type
102+
if (j.contains("object") && j["object"].get<std::string>() == "chat.completion.chunk") {
103+
return true;
104+
}
105+
106+
// Check if it's an array containing objects with the right type
107+
if (j.is_array() && !j.empty() && j[0].contains("object") &&
108+
j[0]["object"].get<std::string>() == "chat.completion.chunk") {
109+
return true;
110+
}
111+
112+
return false;
113+
}
114+
115+
void ChatMemoryCommon::track_response(const std::string& response) {
116+
ChatMemoryCommon::log_debug("track_response: Adding response with size " + std::to_string(response.size()) + " bytes");
117+
118+
recent_responses.push_back(response);
119+
if (recent_responses.size() > max_context_responses) {
120+
ChatMemoryCommon::log_debug("track_response: Removing oldest response (exceeded max_context_responses)");
121+
recent_responses.pop_front();
122+
}
123+
}
124+
125+
// Check if a valid memory command JSON is being used
126+
bool ChatMemoryCommon::is_valid_memory_json(const std::string& output) const {
127+
ChatMemoryCommon::log_debug("is_valid_memory_json: Checking if \"" + output.substr(0, std::min(output.size(), size_t(50))) +
128+
(output.size() > 50 ? "..." : "") + "\" contains valid memory command JSON");
129+
130+
// Look for valid memory_command JSON pattern
131+
std::regex memory_cmd_pattern(R"(\{"memory_command":[^}]+\})");
132+
bool valid = std::regex_search(output, memory_cmd_pattern);
133+
134+
ChatMemoryCommon::log_debug("is_valid_memory_json: Result = " + std::string(valid ? "valid" : "invalid") + " memory command JSON");
135+
return valid;
136+
}
137+
138+
// Main entry point for processing model output and executing commands
139+
std::string ChatMemoryCommon::parse_and_execute_command(const std::string& output) {
140+
log_debug("parse_and_execute_command: Processing output for memory commands");
141+
142+
if (output.find("memory_command") == std::string::npos || output.find('{') == std::string::npos) {
143+
log_debug("parse_and_execute_command: No memory commands found");
144+
return ""; // No memory commands found
145+
}
146+
147+
// Check if this appears to be a valid JSON command structure
148+
if (!is_valid_memory_json(output)) {
149+
log_debug("parse_and_execute_command: Warning - Detected memory-related text without proper JSON format");
150+
std::cerr << "[ChatMemory] Warning: Detected memory-related text without proper JSON format.\n";
151+
// Continue anyway as regex might not catch all valid formats
152+
}
153+
154+
std::regex json_block(R"(\{[^{}]*(\{[^{}]*\}[^{}]*)*\})");
155+
auto begin = std::sregex_iterator(output.begin(), output.end(), json_block);
156+
auto end = std::sregex_iterator();
157+
158+
if (begin == end) {
159+
log_debug("parse_and_execute_command: No JSON blocks found");
160+
std::cerr << "[ChatMemory] No JSON blocks found in output.\n";
161+
return "";
162+
}
163+
164+
for (auto it = begin; it != end; ++it) {
165+
const std::string json_text = it->str();
166+
if (json_text.find("memory_command") == std::string::npos) {
167+
continue;
168+
}
169+
170+
ChatMemoryCommon::log_debug("parse_and_execute_command: Found potential memory command JSON: " +
171+
json_text.substr(0, std::min(json_text.size(), size_t(100))) +
172+
(json_text.size() > 100 ? "..." : ""));
173+
174+
try {
175+
json j = json::parse(json_text);
176+
177+
// Execute the command and get the human-readable response
178+
std::string human_response = execute_json_command(j);
179+
if (!human_response.empty()) {
180+
// Track the response for context management
181+
track_response(human_response);
182+
183+
log_debug("parse_and_execute_command: Successfully executed command, returning response");
184+
return human_response;
185+
}
186+
} catch (const std::exception& e) {
187+
log_debug("parse_and_execute_command: JSON parse error: " + std::string(e.what()));
188+
std::cerr << "[ChatMemory] JSON parse error: " << e.what() << "\n";
189+
std::cerr << "[ChatMemory] Offending input: " << json_text << "\n";
190+
}
191+
}
192+
193+
log_debug("parse_and_execute_command: No valid memory commands found");
194+
return ""; // No valid commands found
195+
}
196+
197+
// Scan a non-chunked completion JSON for memory commands and, if any command
// produced output, append that output to the same field the text came from
// ("content" for chat completions, "text" for regular completions).
void ChatMemoryCommon::parse_and_execute_command_json(json& j) {
    log_debug("parse_and_execute_command_json: Processing JSON response");

    std::string model_output;
    const char* field = nullptr; // which key carried the model text

    // Handle different response formats
    if (j.contains("content")) {
        // Chat completions format
        model_output = j["content"].get<std::string>();
        log_debug("parse_and_execute_command_json: Found content field");
        field = "content";
    } else if (j.contains("text")) {
        // Regular completions format
        model_output = j["text"].get<std::string>();
        log_debug("parse_and_execute_command_json: Found text field");
        field = "text";
    } else {
        // No recognizable output format
        log_debug("parse_and_execute_command_json: No recognizable output format");
        return;
    }

    // Process and append any memory responses
    const std::string memory_response = parse_and_execute_command(model_output);
    if (memory_response.empty()) {
        log_debug("parse_and_execute_command_json: No memory response to append");
    } else {
        log_debug("parse_and_execute_command_json: Found memory response, appending to output");
        j[field] = model_output + "\n" + memory_response;
    }
}
232+
233+
// Accumulate the text content of one streaming chunk into accumulated_content.
// Accepts both a single chunk object and an array of chunk objects; chunks
// without a delta content field are logged and skipped.
//
// Fix vs. original: the array path indexed `j[0]["choices"]` through a
// non-const reference (nlohmann's non-const operator[] INSERTS a null value
// when the key is missing, silently mutating the chunk) and then used const
// operator[] on a possibly-absent "delta" key (undefined behavior / assertion
// failure). All lookups are now guarded with contains().
void ChatMemoryCommon::process_streaming_chunk(json& j) {
    try {
        // First check if it's a direct object with choices
        if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) {
            const auto& first_choice = j["choices"][0];
            if (first_choice.contains("delta") && first_choice["delta"].contains("content")) {
                std::string content = first_choice["delta"]["content"].get<std::string>();
                accumulated_content += content;
                log_debug("Chunk appended: '" + content + "'");
                return;
            }
        }
        // Then check the array case
        else if (j.is_array() && !j.empty() &&
                 j[0].contains("choices") && j[0]["choices"].is_array() && !j[0]["choices"].empty()) {
            const auto& first_choice = j[0]["choices"][0];
            if (first_choice.contains("delta") && first_choice["delta"].contains("content")) {
                std::string content = first_choice["delta"]["content"].get<std::string>();
                accumulated_content += content;
                log_debug("Chunk appended: '" + content + "'");
                return;
            }
        }

        log_debug("Chunk missing 'content' field: " + j.dump());
    } catch (const std::exception &e) {
        log_debug(std::string("Exception parsing chunk: ") + e.what());
    }
}
264+
265+
// Scan a complete (non-streaming) response for memory commands and append any
// command output to the field that carried the model text. Supports the OpenAI
// chat format (choices[0].message.content) plus bare "content"/"text" fields.
//
// Changes vs. original: the repeated substr/ellipsis truncation is hoisted
// into a local lambda, and j.dump() is evaluated once instead of three times
// in the unrecognized-format log line. Log output is byte-identical.
void ChatMemoryCommon::process_regular_response(json& j) {
    log_debug("process_regular_response: Processing standard response format");

    // Truncate long strings for debug logging (same format as before).
    const auto preview = [](const std::string& s, size_t limit) {
        return s.substr(0, std::min(s.size(), limit)) + (s.size() > limit ? "..." : "");
    };

    std::string model_output;
    bool found_content = false;

    // Handle different response formats
    if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) {
        auto& first_choice = j["choices"][0];

        if (first_choice.contains("message") && first_choice["message"].contains("content")) {
            model_output = first_choice["message"]["content"].get<std::string>();
            found_content = true;
            log_debug("process_regular_response: Found content in OpenAI format: \"" +
                      preview(model_output, 100) + "\"");
        } else {
            log_debug("process_regular_response: No content found in OpenAI format");
        }
    } else if (j.contains("content")) {
        model_output = j["content"].get<std::string>();
        found_content = true;
        log_debug("process_regular_response: Found content field: \"" +
                  preview(model_output, 100) + "\"");
    } else if (j.contains("text")) {
        model_output = j["text"].get<std::string>();
        found_content = true;
        log_debug("process_regular_response: Found text field: \"" +
                  preview(model_output, 100) + "\"");
    } else {
        // Serialize once; the original called j.dump() three times here.
        const std::string dumped = j.dump();
        log_debug("process_regular_response: No recognizable output format. JSON structure: " +
                  preview(dumped, 500));
        return;
    }

    if (!found_content || model_output.empty()) {
        log_debug("process_regular_response: No model output found to process");
        return;
    }

    // Process and append any memory responses
    std::string memory_response = parse_and_execute_command(model_output);
    if (!memory_response.empty()) {
        log_debug("process_regular_response: Found memory response, appending to output");

        // Update the appropriate field
        if (j.contains("choices") && j["choices"].is_array() && !j["choices"].empty()) {
            auto& first_choice = j["choices"][0];
            if (first_choice.contains("message") && first_choice["message"].contains("content")) {
                first_choice["message"]["content"] = model_output + "\n" + memory_response;
                log_debug("process_regular_response: Updated content in OpenAI format");
            } else {
                log_debug("process_regular_response: Couldn't update content in OpenAI format");
            }
        } else if (j.contains("content")) {
            j["content"] = model_output + "\n" + memory_response;
            log_debug("process_regular_response: Updated content field");
        } else if (j.contains("text")) {
            j["text"] = model_output + "\n" + memory_response;
            log_debug("process_regular_response: Updated text field");
        } else {
            log_debug("process_regular_response: Couldn't find field to update with memory response");
        }
    } else {
        log_debug("process_regular_response: No memory response to append");
    }
}
335+
336+
// Clear all per-stream state so the next streaming response starts fresh.
void ChatMemoryCommon::reset_streaming() {
    log_debug("reset_streaming: Resetting streaming state");
    accumulated_content.clear();
    in_streaming_mode = false;
}
341+
342+
// Base implementation of memory-command execution.
// NOTE(review): this appears to be a placeholder — it ignores `j` and returns
// the class name; concrete memory implementations presumably override it
// (callers in parse_and_execute_command treat an empty return as "no command
// executed", so this stub always counts as a successful command). Confirm
// against the subclasses (e.g. chat_memory_simple).
std::string ChatMemoryCommon::execute_json_command(nlohmann::ordered_json &j) {
    return "ChatMemoryCommon";
}

0 commit comments

Comments
 (0)