"""
Unit tests for chat formatters
"""

# Third Party
import pytest

# Local
from torchchat.generate import (
    HFTokenizerChatFormatter,
    Llama2ChatFormatter,
    Llama3ChatFormatter,
)
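
# These are plain pytest tests; assuming this file lives at
# tests/test_chat_formatters.py in the torchchat repo, they can be run with:
#
#   python -m pytest tests/test_chat_formatters.py -v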

## Helpers #####################################################################

class DummyTokenizer:
    """Dummy tokenizer whose encode returns the text unchanged, so rendered
    output is easy to check"""
    def encode(self, text, *_, **__):
        return text


class DummySPTokenizer(DummyTokenizer):
    """Emulated SentencePiece tokenizer with bos/eos"""
    bos = "<s>"
    eos = "</s>"


class DummyLlama3Tokenizer(DummyTokenizer):
    """Emulated Llama3 tokenizer whose special_tokens mapping returns each
    token name unchanged"""
    class _IdentityDict:
        def __getitem__(self, key):
            return key
    special_tokens = _IdentityDict()


class DummyHFTokenizer(DummyTokenizer):
    """Dummy tokenizer with a made-up chat template scheme"""
    # Sequence start
    bos = "<bos>"
    # Turn delimiters
    bot = "<bot>"
    eot = "<eot>"
    # Role delimiters
    bor = "<bor>"
    eor = "<eor>"
    def apply_chat_template(self, messages, add_generation_prompt):
        out = [self.bos]
        role = None
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
        if add_generation_prompt and role != "assistant":
            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
        return "\n".join(out)
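

# A quick sanity check of the made-up template scheme above, so the
# HFTokenizerChatFormatter tests below are easier to read. This is an added
# illustration (a sketch), not part of any real HF tokenizer API; the message
# content "hi" is arbitrary.
def test_dummy_hf_tokenizer_template():
    rendered = DummyHFTokenizer().apply_chat_template(
        [{"role": "user", "content": "hi"}], add_generation_prompt=True
    )
    assert rendered == "<bos>\n<bot><bor>user<eor>hi<eot>\n<bot><bor>assistant<eor>"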


def check_rendering(fmt, messages, expected, add_generation_prompt):
    """Render messages and compare to expected output"""
    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected


def make_message(role, text):
    return {"role": role, "content": text}


SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
USER1 = "Hello world!"
ASSISTANT1 = "Greetings! How can I help you?"
USER2 = "Why is the sky blue?"
ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."


# Stock sets of messages to test
MSGS_NO_SYS = [
    make_message("user", USER1),
]
MSGS_SYS_USR = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
]
MSGS_SYS_USR_ASST = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
]
MSGS_MULTI_TURN = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
    make_message("user", USER2),
    make_message("assistant", ASSISTANT2),
]
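

# All three formatters under test share the same interface (exercised by
# check_rendering above): construct with a tokenizer, then render a message
# list. Roughly, and using the dummy tokenizers so the "tokens" come back as
# plain strings:
#
#   fmt = Llama2ChatFormatter(DummySPTokenizer())
#   tokens = fmt.encode_dialog_prompt(MSGS_NO_SYS, True)
#   prompt_text = "".join(tokens)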

## Llama2ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
        # sys, usr
        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST]"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
"""),
    ]
)
def test_llama2_chat_formatter(messages, expected):
    """Tests for Llama2 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
    """
    tok = DummySPTokenizer()
    fmt = Llama2ChatFormatter(tok)
    # NOTE: add_generation_prompt is not used by Llama2
    check_rendering(fmt, messages, expected, True)

## Llama3ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT2}<|eot_id|>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
    """Tests for Llama3 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
    """
    tok = DummyLlama3Tokenizer()
    fmt = Llama3ChatFormatter(tok)
    # No assistant prompt is added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    check_rendering(fmt, messages, expected, add_generation_prompt)

## HFTokenizerChatFormatter ####################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<bos>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>
<bot><bor>user<eor>{USER2}<eot>
<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_hf_chat_formatter(messages, expected, add_generation_prompt):
    """Tests for the HF tokenizer formatter using the dummy template scheme
    defined above"""
    tok = DummyHFTokenizer()
    fmt = HFTokenizerChatFormatter(tok)
    # No assistant prompt is added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
    check_rendering(fmt, messages, expected, add_generation_prompt)