Skip to content

Commit a3b5cff

Browse files
tharun571vgvassilev
authored andcommitted
Refactor xassist, test framework
1 parent f678e0d commit a3b5cff

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

src/xmagics/xassist.cpp

+55-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
using json = nlohmann::json;
2020

2121
// TODO: Implement xplugin to separate the magics from the main code.
22-
// TODO: Add support for open-source models.
2322
namespace xcpp
2423
{
2524
class api_key_manager
@@ -279,6 +278,53 @@ namespace xcpp
279278
}
280279
};
281280

281+
std::string escape_special_cases(const std::string& input)
282+
{
283+
std::string escaped;
284+
for (char c : input)
285+
{
286+
switch (c)
287+
{
288+
case '\\':
289+
escaped += "\\\\";
290+
break;
291+
case '\"':
292+
escaped += "\\\"";
293+
break;
294+
case '\n':
295+
escaped += "\\n";
296+
break;
297+
case '\t':
298+
escaped += "\\t";
299+
break;
300+
case '\r':
301+
escaped += "\\r";
302+
break;
303+
case '\b':
304+
escaped += "\\b";
305+
break;
306+
case '\f':
307+
escaped += "\\f";
308+
break;
309+
default:
310+
if (c < 0x20 || c > 0x7E)
311+
{
312+
// Escape non-printable ASCII characters and non-ASCII characters
313+
std::array<char, 7> buffer{};
314+
std::stringstream ss;
315+
ss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (c & 0xFFFF);
316+
escaped += ss.str();
317+
}
318+
else
319+
{
320+
escaped += c;
321+
}
322+
break;
323+
}
324+
}
325+
return escaped;
326+
}
327+
282328
std::string gemini(const std::string& cell, const std::string& key)
283329
{
284330
curl_helper curl_helper;
@@ -369,8 +415,8 @@ namespace xcpp
369415
}
370416

371417
const std::string post_data = R"({
372-
"model": [)" + model
373-
+ R"(],
418+
"model": ")" + model
419+
+ R"(",
374420
"messages": [)" + chat_message
375421
+ R"(],
376422
"temperature": 0.7
@@ -453,18 +499,21 @@ namespace xcpp
453499
}
454500
}
455501

502+
503+
const std::string prompt = escape_special_cases(cell);
504+
456505
std::string response;
457506
if (model == "gemini")
458507
{
459-
response = gemini(cell, key);
508+
response = gemini(prompt, key);
460509
}
461510
else if (model == "openai")
462511
{
463-
response = openai(cell, key);
512+
response = openai(prompt, key);
464513
}
465514
else if (model == "ollama")
466515
{
467-
response = ollama(cell);
516+
response = ollama(prompt);
468517
}
469518

470519
std::cout << response;

test/test_xcpp_kernel.py

-2
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ def test_notebooks(self):
167167
with open(out) as f:
168168
output_nb = nbformat.read(f, as_version=4)
169169

170-
check = True
171-
172170
# Iterate over the cells in the input and output notebooks
173171
for i, (input_cell, output_cell) in enumerate(zip(input_nb.cells, output_nb.cells)):
174172
if input_cell.cell_type == 'code' and output_cell.cell_type == 'code':

0 commit comments

Comments
 (0)