Commit bfcce4d

tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)
* `tool-call`: support Command R7B (w/ tool_plan return)
* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override
* `tool-call`: test cleanup / handle lazy grammar triggers
1 parent 6980448 commit bfcce4d

8 files changed, +420 −56 lines changed

common/chat.cpp

+84 −2
@@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+        case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }
 
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    common_chat_params data;
+    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        auto schemas = json::array();
+        foreach_function(inputs.tools, [&](const json & tool) {
+            const auto & function = tool["function"];
+            schemas.push_back({
+                {"type", "object"},
+                {"properties", {
+                    {"tool_call_id", {
+                        {"type", "string"},
+                        // Command-R's template expects an integer string.
+                        {"pattern", "^[0-9]{1,10}$"},
+                    }},
+                    {"tool_name", {
+                        {"type", "string"},
+                        {"const", function["name"]},
+                    }},
+                    {"parameters", function["parameters"]},
+                }},
+                {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
+            });
+        });
+        auto schema = json {
+            {"type", "array"},
+            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+            {"minItems", 1},
+        };
+        if (!inputs.parallel_tool_calls) {
+            schema["maxItems"] = 1;
+        }
+        builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
+    }, grammar_options);
+    data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
+    data.preserved_tokens = {
+        "<|START_RESPONSE|>",
+        "<|END_RESPONSE|>",
+        "<|START_THINKING|>",
+        "<|END_THINKING|>",
+        "<|END_ACTION|>",
+    };
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
+    return data;
+}
+static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
+    static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>");
+    static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
+    std::smatch match;
+
+    common_chat_msg result;
+    result.role = "assistant";
+    if (std::regex_match(input, match, response_regex)) {
+        result.content = match[1].str();
+    } else if (std::regex_match(input, match, thought_action_regex)) {
+        result.tool_plan = match[1].str();
+        auto actions_str = match[2].str();
+        auto actions = json::parse(actions_str);
+        for (const auto & action : actions) {
+            result.tool_calls.push_back({
+                /* .name = */ action["tool_name"],
+                /* .arguments = */ action["parameters"].dump(),
+                /* .id = */ action["tool_call_id"],
+            });
+        }
+    } else {
+        LOG_ERR("Failed to parse command_r output");
+        result.content = input;
+    }
+    return result;
+}
+
 static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
     if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
         throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
@@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
             "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
         });
         data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
+        data.preserved_tokens = {
+            "<|tool▁sep|>",
+            "<|tool▁call▁end|>",
+        };
         builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
     }, grammar_options);
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
         builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
-        // Not really a trigger but need to print this special token to get a successful parse.
-        data.grammar_triggers.push_back({"</tool_call>", /* .at_start = */ false});
+        data.preserved_tokens = { "</tool_call>" };
     }, grammar_options);
 
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
         return common_chat_params_init_mistral_nemo(tmpl, inputs);
     }
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
+        return common_chat_params_init_command_r7b(tmpl, inputs);
+    }
     return common_chat_params_init_generic(tmpl, inputs);
 }
 
@@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
             return common_chat_parse_hermes_2_pro(input);
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
             return common_chat_parse_firefunction_v2(input);
+        case COMMON_CHAT_FORMAT_COMMAND_R7B:
+            return common_chat_parse_command_r7b(input);
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }
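For reference, the parsing path above can be exercised in isolation. The sketch below is illustrative only and is not part of the commit: it feeds a hypothetical Command R7B completion through a thought/action regex shaped like the one in `common_chat_parse_command_r7b`, assuming nlohmann::json is available; the tool name `get_weather` and its arguments are invented.

```cpp
#include <iostream>
#include <regex>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Hypothetical completion from Command R7B deciding to call one tool
    // (tool name and arguments are invented for this example).
    const std::string input =
        "<|START_THINKING|>I should look up the weather first.<|END_THINKING|>"
        "<|START_ACTION|>[{\"tool_call_id\": \"0\", \"tool_name\": \"get_weather\","
        " \"parameters\": {\"city\": \"Paris\"}}]<|END_ACTION|>";

    // Same shape as the thought/action regex used above.
    static const std::regex thought_action_regex(
        "<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>"
        "<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");

    std::smatch match;
    if (std::regex_match(input, match, thought_action_regex)) {
        const std::string tool_plan = match[1].str();      // surfaced as "tool_plan" in the API
        const json actions = json::parse(match[2].str());  // becomes the tool_calls array
        std::cout << "tool_plan: " << tool_plan << "\n";
        for (const auto & action : actions) {
            std::cout << "call " << action["tool_name"].get<std::string>()
                      << " with args " << action["parameters"].dump() << "\n";
        }
    } else {
        std::cout << "no tool call found\n";
    }
    return 0;
}
```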

common/chat.hpp

+2
@@ -32,6 +32,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_COMMAND_R7B,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -42,6 +43,7 @@ struct common_chat_params {
     std::string grammar;
     bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string> preserved_tokens;
     std::vector<std::string> additional_stops;
 };

common/common.h

+3
@@ -4,6 +4,7 @@
 
 #include "llama-cpp.h"
 
+#include <set>
 #include <string>
 #include <vector>
 #include <sstream>
@@ -163,6 +164,7 @@ struct common_params_sampling {
     bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
    std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
+    std::set<llama_token> preserved_tokens;
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
@@ -621,6 +623,7 @@ struct common_chat_msg {
     std::string role;
     std::string content;
     std::vector<common_tool_call> tool_calls;
+    std::string tool_plan = "";
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

examples/server/README.md

+15 −7
@@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
 - Hermes 2/3, Qwen 2.5
 - Mistral Nemo
 - Firefunction v2
+- Command R7B
 - DeepSeek R1 (WIP / seems reluctant to call any tools?)
 
 <details>
@@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
 ```shell
 # Native support:
 llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
+llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
 llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
-    --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B )
+llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
 
 # Native support requires the right template for these GGUFs:
+
+llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
+    --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
+
 llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
     --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
+
 llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
-    --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
+    --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
+
+llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
+    --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
 
 # Generic format support
-llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M
+llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
+llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
+llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
 ```
 
 - Test in CLI:

examples/server/server.cpp

+38 −14
@@ -131,6 +131,11 @@ struct slot_params {
             lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
         }
 
+        std::vector<std::string> grammar_trigger_words;
+        for (const auto & trigger : sampling.grammar_trigger_words) {
+            grammar_trigger_words.push_back(trigger.word);
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
@@ -165,8 +170,9 @@ struct slot_params {
             {"n_probs", sampling.n_probs},
             {"min_keep", sampling.min_keep},
             {"grammar", sampling.grammar},
-            // {"grammar_trigger_words", sampling.grammar_trigger_words},
+            {"grammar_trigger_words", grammar_trigger_words},
             {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
+            {"preserved_tokens", sampling.preserved_tokens},
             {"samplers", samplers},
             {"speculative.n_max", speculative.n_max},
             {"speculative.n_min", speculative.n_min},
@@ -363,12 +369,26 @@ struct server_task {
                 if (ids.size() == 1) {
                     LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                     params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                    params.sampling.preserved_tokens.insert(ids[0]);
                     continue;
                 }
                 LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                 params.sampling.grammar_trigger_words.push_back(trigger);
             }
         }
+        const auto preserved_tokens = data.find("preserved_tokens");
+        if (preserved_tokens != data.end()) {
+            for (const auto & t : *preserved_tokens) {
+                auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
+                if (ids.size() == 1) {
+                    LOG_DBG("Preserved token: %d\n", ids[0]);
+                    params.sampling.preserved_tokens.insert(ids[0]);
+                } else {
+                    // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
+                    LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
+                }
+            }
+        }
         if (params.sampling.grammar_lazy) {
             GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
         }
@@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
 
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
-        common_chat_msg message;
+        common_chat_msg msg;
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             LOG_DBG("Parsing chat message: %s\n", content.c_str());
-            message = common_chat_parse(content, oaicompat_chat_format);
-            finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
+            msg = common_chat_parse(content, oaicompat_chat_format);
+            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
         } else {
-            message.content = content;
+            msg.content = content;
         }
 
         json tool_calls;
-        if (!message.tool_calls.empty()) {
+        if (!msg.tool_calls.empty()) {
             tool_calls = json::array();
-            for (const auto & tc : message.tool_calls) {
+            for (const auto & tc : msg.tool_calls) {
                 tool_calls.push_back({
                     {"type", "function"},
                     {"function", {
@@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
             }
         }
 
+        json message {
+            {"content", msg.content},
+            {"tool_calls", tool_calls},
+            {"role", "assistant"},
+        };
+        if (!msg.tool_plan.empty()) {
+            message["tool_plan"] = msg.tool_plan;
+        }
+
         json choice {
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", json {
-                {"content", message.content},
-                {"tool_calls", tool_calls},
-                {"role", "assistant"},
-            }},
+            {"message", message},
         };
 
         if (!stream && probs_output.size() > 0) {
@@ -2833,8 +2858,7 @@ struct server_context {
         server_slot * slot_batched = nullptr;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
-            const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
-            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+            return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
         };
 
         // frist, add sampled tokens from any ongoing sequences
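To make the API-visible effect of the `to_json_oaicompat_chat()` change concrete, here is a minimal sketch (not the server code) of the assistant `message` object it now assembles: `tool_plan` is attached only when the parsed message carried a non-empty plan. The tool name, arguments, and plan text below are invented for illustration.

```cpp
#include <iostream>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // In the server these values come from common_chat_parse(); here they are made up.
    std::string content;                                             // empty when the model only called tools
    std::string tool_plan = "I should look up the weather first.";   // from <|START_THINKING|>...<|END_THINKING|>

    json tool_call {
        {"type", "function"},
        {"function", {
            {"name", "get_weather"},
            {"arguments", "{\"city\": \"Paris\"}"},
        }},
    };

    json message {
        {"content", content},
        {"tool_calls", json::array({tool_call})},
        {"role", "assistant"},
    };
    if (!tool_plan.empty()) {
        message["tool_plan"] = tool_plan;   // the new field returned by this commit
    }
    std::cout << message.dump(2) << std::endl;
}
```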

examples/server/utils.hpp

+1
@@ -662,6 +662,7 @@ static json oaicompat_completion_params_parse(
         });
     }
     llama_params["grammar_triggers"] = grammar_triggers;
+    llama_params["preserved_tokens"] = chat_params.preserved_tokens;
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }
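End to end: the chat handler emits preserved token strings, `oaicompat_completion_params_parse()` forwards them as `preserved_tokens` (above), and the server keeps only the ones that tokenize to a single vocab token, warning otherwise (the server.cpp hunk). A standalone sketch of that filtering step, with a fake tokenizer standing in for `common_tokenize`; it is illustrative only, not the server code.

```cpp
#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <vector>

using llama_token = int;  // stand-in for llama.cpp's token type

// Keep only strings that map to exactly one token; warn about the rest.
static void collect_preserved_tokens(
        const std::vector<std::string> & preserved_strings,
        const std::function<std::vector<llama_token>(const std::string &)> & tokenize,
        std::set<llama_token> & preserved_tokens) {
    for (const auto & s : preserved_strings) {
        const auto ids = tokenize(s);
        if (ids.size() == 1) {
            preserved_tokens.insert(ids[0]);  // rendered in output even without --special
        } else {
            std::cerr << "not preserved, more than 1 token (wrong chat template override?): " << s << "\n";
        }
    }
}

int main() {
    // Fake tokenizer: pretend every <|...|> marker is a single special token, id 42.
    auto tokenize = [](const std::string & s) {
        return s.rfind("<|", 0) == 0 ? std::vector<llama_token>{42} : std::vector<llama_token>{1, 2, 3};
    };
    std::set<llama_token> preserved;
    collect_preserved_tokens({"<|END_ACTION|>", "plain text"}, tokenize, preserved);
    std::cout << "preserved " << preserved.size() << " token(s)\n";
}
```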
