Commit e5113e8: Add --jinja and --chat-template-file flags

1 parent: abd274a

12 files changed: +289, -50 lines

Makefile (+2)

@@ -1361,7 +1361,9 @@ llama-server: \
     examples/server/httplib.h \
     examples/server/index.html.hpp \
     examples/server/loading.html.hpp \
+    common/chat-template.hpp \
     common/json.hpp \
+    common/minja.hpp \
     $(OBJ_ALL)
     $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
     $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

common/CMakeLists.txt (+2)

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp

common/arg.cpp (+39, -4)

@@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
+            if (!common_chat_verify_template(value, params.use_jinja)) {
                 throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
                 ));
             }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        "set custom jinja chat template file (default: template taken from model's metadata)\n"
+        "if suffix/prefix are specified, template will be disabled\n"
+        "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string chat_template;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(chat_template)
+            );
+            if (!common_chat_verify_template(chat_template, params.use_jinja)) {
+                throw std::runtime_error(string_format(
+                    "error: the supplied chat template is not supported: %s%s\n",
+                    value.c_str(),
+                    params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates"
+                ));
+            }
+            params.chat_template = chat_template;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

common/common.cpp (+64, -4)

@@ -1576,13 +1576,13 @@ std::vector<llama_token> common_tokenize(
     return result;
 }
 
-std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }
 
+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    return _common_token_to_piece(llama_get_model(ctx), token, special);
+}
+
 std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
@@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
 // Chat template utils
 //
 
-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
@@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
+static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
+    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
+    if (tlen > 0) {
+        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
+        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
+            return std::string(curr_tmpl_buf.data(), tlen);
+        }
+    }
+    return "";
+}
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
+    auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
+    std::string default_template_src = chat_template_override;
+    std::string tool_use_template_src = chat_template_override;
+    if (chat_template_override.empty()) {
+        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!tool_use_template_src.empty()) {
+            default_template_src = tool_use_template_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    return {
+        .default_template = { default_template_src, bos_token, eos_token },
+        .tool_use_template = tool_use_template_src.empty() ? std::nullopt
+            : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
+    };
+}
+
 //
 // KV cache utils
 //
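
A hedged sketch of how a caller might consume the new `llama_chat_templates_from_model` helper, assuming this commit's `common.h`/`chat-template.hpp` headers, the vendored `json.hpp`, and a GGUF model at a placeholder path; the model path, message content, and the explicit `nlohmann::ordered_json` spelling are assumptions of the sketch, not part of the diff.

```cpp
// Sketch only: compile against llama.cpp's common library from this commit.
#include "common.h"   // declares llama_chat_templates / llama_chat_templates_from_model
#include "json.hpp"   // vendored nlohmann json; minja works with ordered_json
#include "llama.h"

#include <iostream>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    // Empty override: templates come from model metadata, with a ChatML fallback.
    auto templates = llama_chat_templates_from_model(model, "");

    // Print the template source, as the /props endpoint now does.
    std::cout << templates.default_template.source() << "\n";

    // Render a one-message conversation, the same call pattern used by
    // common_chat_verify_template above: (messages, tools, add_generation_prompt).
    std::cout << templates.default_template.apply({{
        {"role", "user"},
        {"content", "Hello"},
    }}, nlohmann::ordered_json(), true) << "\n";

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

The tool-use variant, when the model ships one, is exposed as `templates.tool_use_template` and, being a `std::optional`, must be checked before it is dereferenced.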

common/common.h (+12, -2)

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "llama.h"
+#include "chat-template.hpp"
 
 #include <string>
 #include <vector>
@@ -324,6 +325,7 @@ struct common_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
+    bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
@@ -571,8 +573,8 @@ struct common_chat_msg {
     std::string content;
 };
 
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+// Check if the template is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
@@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model,
 std::string common_chat_format_example(const struct llama_model * model,
     const std::string & tmpl);
 
+
+struct llama_chat_templates {
+    minja::chat_template default_template;
+    std::optional<minja::chat_template> tool_use_template;
+};
+
+llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
+
 //
 // KV cache utils
 //
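
Because `tool_use_template` is a `std::optional`, callers check it before dereferencing. A small hypothetical helper sketching the selection logic the server uses (the `select_template` name and `has_tools` flag are illustrative and assume this commit's headers, they are not part of the diff):

```cpp
// Hypothetical helper: choose which template to render a request with.
static const minja::chat_template & select_template(const llama_chat_templates & templates, bool has_tools) {
    // Prefer the model's dedicated tool-use template when the request carries
    // tools and the model ships one; otherwise fall back to the default template.
    if (has_tools && templates.tool_use_template) {
        return *templates.tool_use_template;
    }
    return templates.default_template;
}
```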

examples/server/README.md (+1, -1)

@@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
-
+| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
 
 **Example-specific params**
 
examples/server/server.cpp (+53, -14)

@@ -1623,15 +1623,35 @@ struct server_context {
         return true;
     }
 
-    bool validate_model_chat_template() const {
-        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
-        std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        if (res >= 0) {
-            llama_chat_message chat[] = {{"user", "test"}};
-            std::string tmpl = std::string(model_template.data(), model_template.size());
-            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
-            return chat_res > 0;
+    bool validate_model_chat_template(bool use_jinja) const {
+        llama_chat_message chat[] = {{"user", "test"}};
+
+        if (use_jinja) {
+            auto templates = llama_chat_templates_from_model(model, "");
+            try {
+                templates.default_template.apply({{
+                    {"role", "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                if (templates.tool_use_template) {
+                    templates.tool_use_template->apply({{
+                        {"role", "user"},
+                        {"content", "test"},
+                    }}, json(), true);
+                }
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+            }
+        } else {
+            std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+            std::string template_key = "tokenizer.chat_template";
+            int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+            if (res >= 0) {
+                std::string tmpl = std::string(model_template.data(), model_template.size());
+                int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+                return chat_res > 0;
+            }
         }
         return false;
     }
@@ -3476,15 +3496,30 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    std::mutex chat_templates_mutex;
+    std::optional<llama_chat_templates> chat_templates;
+
+    auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & {
+        std::lock_guard<std::mutex> lock(chat_templates_mutex);
+        if (!chat_templates) {
+            chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template);
+        }
+        return *chat_templates;
+    };
+
+    const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
+        const auto & templates = get_chat_templates();
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
             { "model_path", ctx_server.params_base.model },
-            { "chat_template", llama_get_chat_template(ctx_server.model) },
+            { "chat_template", templates.default_template.source() },
             { "build_info", build_info },
         };
+        if (ctx_server.params_base.use_jinja && templates.tool_use_template) {
+            data["chat_template_tool_use"] = templates.tool_use_template->source();
+        }
 
         res_ok(res, data);
     };
@@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        auto body = json::parse(req.body);
+        const auto & templates = get_chat_templates();
+        const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+
         return handle_completions_generic(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!ctx_server.validate_model_chat_template(params.use_jinja)) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }
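
The `get_chat_templates` lambda above builds the templates lazily, once, under a mutex, and hands out a reference to the cached value; presumably this keeps Jinja parsing off the startup path when no handler ever needs it. A self-contained sketch of the same pattern using only the standard library (all names here are illustrative, not llama.cpp identifiers):

```cpp
#include <iostream>
#include <mutex>
#include <optional>
#include <string>

struct parsed_templates {          // stand-in for llama_chat_templates
    std::string source;
};

int main() {
    std::mutex                      templates_mutex;
    std::optional<parsed_templates> templates;

    // Build the value on first use, then reuse the cached copy; the lock makes
    // concurrent first calls safe, as in the server's get_chat_templates lambda.
    auto get_templates = [&]() -> const parsed_templates & {
        std::lock_guard<std::mutex> lock(templates_mutex);
        if (!templates) {
            templates = parsed_templates{"parsed on first use"};
        }
        return *templates;
    };

    std::cout << get_templates().source << "\n"; // first call: builds the value
    std::cout << get_templates().source << "\n"; // later calls: reuse the cache
    return 0;
}
```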

examples/server/tests/unit/test_chat_completion.py (+9, -6)

@@ -4,22 +4,24 @@
 
 server = ServerPreset.tinyllama2()
 
-
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
     global server
+    server.jinja = jinja
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,
@@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library():
 
 @pytest.mark.parametrize("response_format,n_predicted,re_content", [
     ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
     ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
     ({"type": "json_object"}, 10, "(\\{|John)+"),
     ({"type": "sound"}, 0, None),
