
Commit 90749d2

Add Granite code support (#1336)
* feat(models): Add models.json blocks for Granite Code 3b and 8b
* feat: Initial model params for granite code 3b
* fix(model config): Fix model configs for Granite Code
  * Use the right tokenizer_file name
  * Use the right transformer_params_key based on the file name in model_params
  * Use the updated name to indicate HF tokenizers
* feat(granite): Add model params for granite-code-8b. Something isn't quite working with this model yet, but the config should be accurate at this point.
* fix(deps): Add tokenizers to the deps explicitly. It was implicitly being pulled in via lm_eval -> transformers, but it's better to have it explicit since we use it directly.
* feat(tokenizer): Add basic support for jinja2 template rendering for HF tokenizers. This is a much simplified version of the corresponding logic in transformers; I opted for this so that the full transformers dependency is not added here. CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1522
* fix(chat): Add HFTokenizerChatFormatter and use it for HF tokenizers. This allows the jinja2 templates for HF tokenizers to be applied without needing to hard-code the formatter logic. This will likely need to be duplicated in the embedded code version of chat.
* fix(deps): Add jinja2 as an explicit dep. It was getting pulled in implicitly via flask and lm_eval -> transformers, but it's better to have it explicit.
* feat(log): Add env-based LOG_LEVEL config to the CLI
* feat(log): Add better logging in model and generate. In generate, there were a number of commented-out log lines; these are safe to leave in as long as lazy string interpolation is used.
* feat(generate): Make prepending BOS model-configurable, and disable it for Granite Code models
* fix(chat): Refactor chat template logic to encapsulate all formatting in classes. The formatted strings may not be perfectly 1:1 with the previous impl, but they should be in line with the official model guidelines:
  * https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
  * https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2
* fix(chat): Fix small formatting bugs in the llama3 chat formatter
* test: Add initial unit tests for chat formatters. There's no formal execution framework for pytest yet, but these were helpful in ensuring that the formatting was working correctly! To run them, install pytest and run `pytest tests/`.
* fix(logging): Disable logging in generate unless set in the env. There is an incompatibility between logging and torch._dynamo, so this disables logging unless the developer asks for it explicitly. NOTE: The TC team has stated that they have holistic logging on the roadmap, so this is a short-term solution pending a more robust approach. REF: https://github.com/pytorch/torchchat/actions/runs/11963066986/job/33493237302#step:14:3599
* fix: Remove the trailing \n from the llama3 <|eot_id|>. The documentation is inconsistent on whether there should be a \n after <|eot_id|>, but this maintains consistency with the previous formatting.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
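The chat-formatting refactor above moves all prompt assembly into formatter classes, and the new HF path simply delegates to the tokenizer's own jinja2 chat template. Below is a rough, self-contained sketch of that shape; the base-class name and internals are illustrative stand-ins, and only `HFTokenizerChatFormatter` / `encode_dialog_prompt` appear in the real code (see the tests added in this commit).

```python
# Rough sketch of the "encapsulate all formatting in classes" refactor.
# Names other than HFTokenizerChatFormatter / encode_dialog_prompt are
# simplified stand-ins, not the exact torchchat/generate.py implementation.
from abc import ABC, abstractmethod
from typing import Dict, List


class ChatFormatterSketch(ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_dialog_prompt(
        self, dialog: List[Dict[str, str]], add_generation_prompt: bool = True
    ):
        """Turn a list of {role, content} messages into model-ready tokens."""


class HFChatFormatterSketch(ChatFormatterSketch):
    """Delegates prompt assembly to the tokenizer's jinja2 chat template."""

    def encode_dialog_prompt(self, dialog, add_generation_prompt=True):
        rendered = self.tokenizer.apply_chat_template(
            dialog, add_generation_prompt=add_generation_prompt
        )
        return self.tokenizer.encode(rendered)
```

The Llama2/Llama3 formatters keep their hard-coded prompt formats, so only HF tokenizers that actually ship a `chat_template` take the jinja2 path.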
1 parent fd1857a commit 90749d2

File tree

10 files changed: +469 -75 lines changed

install/requirements.txt

+4
@@ -9,6 +9,10 @@ gguf
 # Tiktoken tokenizer for Llama 3 and other advanced models
 tiktoken
 
+# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
+tokenizers
+jinja2
+
 # Miscellaneous
 snakeviz
 sentencepiece
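A throwaway snippet (not part of the change) to confirm the two newly explicit dependencies resolve after `pip install -r install/requirements.txt`:

```python
# Sanity check that the newly explicit deps import cleanly; the versions are
# whatever the requirements resolution picked.
import jinja2
import tokenizers

print("jinja2:", jinja2.__version__)
print("tokenizers:", tokenizers.__version__)
```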

tests/conftest.py

+12
@@ -0,0 +1,12 @@
+"""
+Global pytest config, fixtures, and helpers go here!
+"""
+
+# Standard
+import os
+import sys
+
+# Make sure tests can import torchchat
+sys.path.append(
+    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+)

tests/test_chat_formatters.py

+216
@@ -0,0 +1,216 @@
+"""
+Unit tests for chat formatters
+"""
+
+# Third Party
+import pytest
+
+# Local
+from torchchat.generate import (
+    HFTokenizerChatFormatter,
+    Llama2ChatFormatter,
+    Llama3ChatFormatter,
+)
+
+## Helpers #####################################################################
+
+class DummyTokenizer:
+    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
+    def encode(self, text, *_, **__):
+        return text
+
+
+class DummySPTokenizer(DummyTokenizer):
+    """Emulated Sentencepiece tokenizer with bos/eos"""
+    bos = "<s>"
+    eos = "</s>"
+
+
+class DummyLlama3Tokenizer(DummyTokenizer):
+    class _IdentityDict:
+        def __getitem__(self, key):
+            return key
+    special_tokens = _IdentityDict()
+
+
+class DummyHFTokenizer(DummyTokenizer):
+    """Dummy made up chat template scheme"""
+    # Sequence
+    bos = "<bos>"
+    # Turn
+    bot = "<bot>"
+    eot = "<eot>"
+    # Role
+    bor = "<bor>"
+    eor = "<eor>"
+    def apply_chat_template(self, messages, add_generation_prompt):
+        out = [self.bos]
+        role = None
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
+        if add_generation_prompt and role != "assistant":
+            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
+        return "\n".join(out)
+
+
+def check_rendering(fmt, messages, expected, add_generation_prompt):
+    """Render messages and compare to expected output"""
+    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected
+
+
+def make_message(role, text):
+    return {"role": role, "content": text}
+
+
+SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
+USER1 = "Hello world!"
+ASSISTANT1 = "Greetings! How can I help you?"
+USER2 = "Why is the sky blue?"
+ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."
+
+
+# Stock sets of messages to test
+MSGS_NO_SYS = [
+    make_message("user", USER1),
+]
+MSGS_SYS_USR = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+]
+MSGS_SYS_USR_ASST = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+]
+MSGS_MULTI_TURN = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+    make_message("user", USER2),
+    make_message("assistant", ASSISTANT2),
+]
+
+## Llama2ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST]"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
+"""),
+    ]
+)
+def test_llama2_chat_formatter(messages, expected):
+    """Tests for Llama2 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+    tok = DummySPTokenizer()
+    fmt = Llama2ChatFormatter(tok)
+    # NOTE: add_generation_prompt not used by Llama2
+    check_rendering(fmt, messages, expected, True)
+
+## Llama3ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT2}<|eot_id|>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
+    """Tests for Llama3 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+    """
+    tok = DummyLlama3Tokenizer()
+    fmt = Llama3ChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
+
+## HFTokenizerChatFormatter ####################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<bos>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>
+<bot><bor>user<eor>{USER2}<eot>
+<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_chat_formatter(messages, expected, add_generation_prompt):
+    tok = DummyHFTokenizer()
+    fmt = HFTokenizerChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
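The commit message suggests `pytest tests/`; an equivalent programmatic invocation, assuming pytest is installed and the working directory is the repo root (so the conftest.py above puts torchchat on sys.path), is:

```python
# Run just the formatter tests; pytest.main accepts the same arguments as the
# CLI and returns an exit code (0 when everything passes).
import pytest

exit_code = pytest.main(["-q", "tests/test_chat_formatters.py"])
print("pytest exit code:", exit_code)
```

Note that the stacked `parametrize` decorators on the Llama3 and HF tests expand to four message sets times two `add_generation_prompt` values, i.e. eight cases per formatter.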

tokenizer/hf_tokenizer.py

+27 -1

@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 import os
 
 # Third Party
+import jinja2
 from tokenizers import Tokenizer
 
 # Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
         # Load the tokenizer itself
         self._tokenizer = Tokenizer.from_file(tokenizer_path)
 
+        # Load the chat template if we have a config path
+        self._chat_template: Optional[jinja2.Template] = None
+
         # If available, parse bos/eos tokens from the tokenizer config
         self._bos_id, self._eos_id = None, None
         if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
                 self._bos_id = self._tokenizer.token_to_id(bos_token)
             if eos_token is not None:
                 self._eos_id = self._tokenizer.token_to_id(eos_token)
+            if chat_template_str := tok_config.get("chat_template"):
+                self._chat_template = jinja2.Template(chat_template_str)
 
         # If no eos/bos tokens found, go looking for them!
         if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
         if len(candidate_toks) == 1:
             return candidate_toks[0]["id"]
 
+    ## Interface ##
+
     def encode(
         self,
         s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
 
     def eos_id(self) -> int:
         return self._eos_id
+
+    ## Additional Public Methods ##
+
+    def has_chat_template(self) -> bool:
+        return bool(self._chat_template)
+
+    def apply_chat_template(
+        self,
+        dialog: List[Dict[str, str]],
+        add_generation_prompt: bool = False,
+    ) -> str:
+        """If configured with a chat template, apply it to the list of messages
+        """
+        if not self._chat_template:
+            raise ValueError("No chat template configured!")
+        return self._chat_template.render(
+            messages=dialog, add_generation_prompt=add_generation_prompt
+        )
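To make the new plumbing concrete: the constructor now reads an optional `chat_template` jinja2 string from `tokenizer_config.json` and compiles it once, and `apply_chat_template` renders it with the dialog. A standalone sketch of that flow, using a made-up config rather than Granite Code's real template:

```python
# Standalone sketch of the config -> template -> render flow added above.
# The config dict here is illustrative only; a real tokenizer_config.json
# carries the model's own chat_template string.
import jinja2

tok_config = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "chat_template": (
        "{% for message in messages %}"
        "<|{{ message['role'] }}|> {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|assistant|> {% endif %}"
    ),
}

chat_template = None
if chat_template_str := tok_config.get("chat_template"):
    chat_template = jinja2.Template(chat_template_str)

dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Why is the sky blue?"},
]
if chat_template is not None:  # mirrors has_chat_template()
    print(chat_template.render(messages=dialog, add_generation_prompt=True))
```

The version in transformers layers template caching and a sandboxed jinja2 environment on top of this same idea; the simplified version here skips that to avoid the transformers dependency.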

torchchat/cli/cli.py

+9 -1

@@ -17,7 +17,15 @@
     allowable_params_table,
 )
 
-logging.basicConfig(level=logging.INFO, format="%(message)s")
+_log_level_env = os.getenv("LOG_LEVEL", "INFO")
+try:
+    _log_level = getattr(logging, _log_level_env.upper())
+except AttributeError:
+    print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
+    _log_level = logging.INFO
+
+
+logging.basicConfig(level=_log_level, format="%(message)s")
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
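The net effect of the hunk above is that `LOG_LEVEL=DEBUG` (or any valid `logging` level name, case-insensitive) controls CLI verbosity, with a fallback to INFO plus a stderr warning for unrecognized values. A standalone demonstration of that fallback logic:

```python
# Mirrors the env-driven level selection above, outside of torchchat, to show
# the fallback behavior for an invalid LOG_LEVEL value.
import logging
import os
import sys

for value in ("DEBUG", "warning", "not-a-level"):
    os.environ["LOG_LEVEL"] = value
    _log_level_env = os.getenv("LOG_LEVEL", "INFO")
    try:
        _log_level = getattr(logging, _log_level_env.upper())
    except AttributeError:
        print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
        _log_level = logging.INFO
    print(f"LOG_LEVEL={value!r} -> {logging.getLevelName(_log_level)}")
```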
