@@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
16
16
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return " Functionary v3.2" ;
17
17
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return " Functionary v3.1 Llama 3.1" ;
18
18
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return " Hermes 2 Pro" ;
19
+ case COMMON_CHAT_FORMAT_COMMAND_R7B: return " Command R7B" ;
19
20
default :
20
21
throw std::runtime_error (" Unknown chat format" );
21
22
}
@@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
317
318
return parse_prefixed_json_tool_call_array (input, " [TOOL_CALLS]" );
318
319
}
319
320
321
+ static common_chat_params common_chat_params_init_command_r7b (const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
322
+ common_chat_params data;
323
+ data.grammar_lazy = inputs.tool_choice != " required" ;
324
+ data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
325
+ auto schemas = json::array ();
326
+ foreach_function (inputs.tools , [&](const json & tool) {
327
+ const auto & function = tool[" function" ];
328
+ schemas.push_back ({
329
+ {" type" , " object" },
330
+ {" properties" , {
331
+ {" tool_call_id" , {
332
+ {" type" , " string" },
333
+ // Command-R's template expects an integer string.
334
+ {" pattern" , " ^[0-9]{1,10}$" },
335
+ }},
336
+ {" tool_name" , {
337
+ {" type" , " string" },
338
+ {" const" , function[" name" ]},
339
+ }},
340
+ {" parameters" , function[" parameters" ]},
341
+ }},
342
+ {" required" , json::array ({" tool_call_id" , " tool_name" , " parameters" })},
343
+ });
344
+ });
345
+ auto schema = json {
346
+ {" type" , " array" },
347
+ {" items" , schemas.size () == 1 ? schemas[0 ] : json {{" anyOf" , schemas}}},
348
+ {" minItems" , 1 },
349
+ };
350
+ if (!inputs.parallel_tool_calls ) {
351
+ schema[" maxItems" ] = 1 ;
352
+ }
353
+ builder.add_rule (" root" , " \" <|START_ACTION|>\" " + builder.add_schema (" tool_calls" , schema) + " \" <|END_ACTION|>\" " );
354
+ }, grammar_options);
355
+ data.grammar_triggers .push_back ({" <|START_ACTION|>" , /* .at_start = */ false });
356
+ data.preserved_tokens = {
357
+ " <|START_RESPONSE|>" ,
358
+ " <|END_RESPONSE|>" ,
359
+ " <|START_THINKING|>" ,
360
+ " <|END_THINKING|>" ,
361
+ " <|END_ACTION|>" ,
362
+ };
363
+ data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
364
+ data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
365
+ return data;
366
+ }
367
+ static common_chat_msg common_chat_parse_command_r7b (const std::string & input) {
368
+ static std::regex response_regex (" <\\ |START_RESPONSE\\ |>(.*?)<\\ |END_RESPONSE\\ |>" );
369
+ static std::regex thought_action_regex (" <\\ |START_THINKING\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_THINKING\\ |><\\ |START_ACTION\\ |>([\\ s\\ S\\ n\\ r]*?)<\\ |END_ACTION\\ |>" );
370
+ std::smatch match;
371
+
372
+ common_chat_msg result;
373
+ result.role = " assistant" ;
374
+ if (std::regex_match (input, match, response_regex)) {
375
+ result.content = match[1 ].str ();
376
+ } else if (std::regex_match (input, match, thought_action_regex)) {
377
+ result.tool_plan = match[1 ].str ();
378
+ auto actions_str = match[2 ].str ();
379
+ auto actions = json::parse (actions_str);
380
+ for (const auto & action : actions) {
381
+ result.tool_calls .push_back ({
382
+ /* .name = */ action[" tool_name" ],
383
+ /* .arguments = */ action[" parameters" ].dump (),
384
+ /* .id = */ action[" tool_call_id" ],
385
+ });
386
+ }
387
+ } else {
388
+ LOG_ERR (" Failed to parse command_r output" );
389
+ result.content = input;
390
+ }
391
+ return result;
392
+ }
393
+
320
394
static void expect_tool_parameters (const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
321
395
if (!parameters.is_object () || !parameters.contains (" type" ) || parameters[" type" ] != " object" || !parameters.contains (" properties" ) || !parameters.contains (" required" )) {
322
396
throw std::runtime_error (" Parameters of tool " + name + " must be an object w/ required properties" );
@@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
462
536
" \" <|tool▁call▁begin|>function<|tool▁sep|>" + name + " \\ n```json\\ n\" " + args_rule + " \" ```<|tool▁call▁end|>\" " ));
463
537
});
464
538
data.grammar_triggers .push_back ({" <|tool▁calls▁begin|>" , /* .at_start = */ false });
539
+ data.preserved_tokens = {
540
+ " <|tool▁sep|>" ,
541
+ " <|tool▁call▁end|>" ,
542
+ };
465
543
builder.add_rule (" root" , " \" <|tool▁calls▁begin|>\" (" + string_join (tool_rules, " | " ) + " )" + (inputs.parallel_tool_calls ? " *" : " " ) + " space" );
466
544
}, grammar_options);
467
545
data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
704
782
auto tool_call = " \" <tool_call>\" space " + builder.add_rule (" tool_call" , string_join (tool_rules, " | " )) + " \" </tool_call>\" space" ;
705
783
builder.add_rule (" root" , inputs.parallel_tool_calls ? " (" + tool_call + " )+" : tool_call);
706
784
data.grammar_triggers .push_back ({" <tool_call>" , /* .at_start = */ false });
707
- // Not really a trigger but need to print this special token to get a successful parse.
708
- data.grammar_triggers .push_back ({" </tool_call>" , /* .at_start = */ false });
785
+ data.preserved_tokens = { " </tool_call>" };
709
786
}, grammar_options);
710
787
711
788
data.prompt = tmpl.apply (inputs.messages , inputs.tools .empty () ? json () : inputs.tools , inputs.add_generation_prompt );
@@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
822
899
if (src.find (" [TOOL_CALLS]" ) != std::string::npos) {
823
900
return common_chat_params_init_mistral_nemo (tmpl, inputs);
824
901
}
902
+ if (src.find (" <|END_THINKING|><|START_ACTION|>" ) != std::string::npos) {
903
+ return common_chat_params_init_command_r7b (tmpl, inputs);
904
+ }
825
905
return common_chat_params_init_generic (tmpl, inputs);
826
906
}
827
907
@@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
855
935
return common_chat_parse_hermes_2_pro (input);
856
936
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
857
937
return common_chat_parse_firefunction_v2 (input);
938
+ case COMMON_CHAT_FORMAT_COMMAND_R7B:
939
+ return common_chat_parse_command_r7b (input);
858
940
default :
859
941
throw std::runtime_error (" Unsupported format: " + common_chat_format_name (format));
860
942
}
0 commit comments