Skip to content

Commit 0d4e322

Browse files
authored
[Feature] Add test for speculative_token_map (sgl-project#4016)
1 parent 926f8ef commit 0d4e322

File tree

2 files changed

+75
-14
lines changed

2 files changed

+75
-14
lines changed

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34+
def load_token_map(token_map_path: str) -> List[int]:
    """Load a speculative-decoding hot-token map from disk or the HF Hub.

    Args:
        token_map_path: Local path to a ``.pt`` file, or a Hugging Face
            repo-relative path of the form ``"<repo_id>/<filename>"``
            (e.g. ``"thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt"``).

    Returns:
        The token-id list stored in the file.
    """
    if not os.path.exists(token_map_path):
        # Not a local file: treat the dirname as a HF repo id and download it.
        # Skip the large weight files -- only the token-map artifact is needed.
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    # map_location="cpu" so loading works on CPU-only hosts even if the
    # artifact was saved from a CUDA tensor; the caller moves the ids to the
    # proper device itself (torch.tensor(..., device=head.device)).
    return torch.load(token_map_path, map_location="cpu")
42+
43+
3444
class EAGLEWorker(TpModelWorker):
3545

3646
def __init__(
@@ -48,20 +58,12 @@ def __init__(
4858
server_args.disable_cuda_graph = True
4959

5060
if server_args.speculative_token_map is not None:
51-
if os.path.exists(server_args.speculative_token_map):
52-
self.hot_token_id = torch.load(server_args.speculative_token_map)
53-
else:
54-
cache_dir = snapshot_download(
55-
os.path.dirname(server_args.speculative_token_map),
56-
ignore_patterns=["*.bin", "*.safetensors"],
57-
)
58-
file_path = os.path.join(
59-
cache_dir, os.path.basename(server_args.speculative_token_map)
60-
)
61-
self.hot_token_id = torch.load(file_path)
61+
self.hot_token_id = load_token_map(server_args.speculative_token_map)
6262
server_args.json_model_override_args = (
6363
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
6464
)
65+
else:
66+
self.hot_token_id = None
6567

6668
super().__init__(
6769
gpu_id=gpu_id,
@@ -84,14 +86,12 @@ def __init__(
8486

8587
# Share the embedding and lm_head
8688
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
87-
if server_args.speculative_token_map is not None:
89+
if self.hot_token_id is not None:
8890
head = head.clone()
8991
self.hot_token_id = torch.tensor(
9092
self.hot_token_id, dtype=torch.int32, device=head.device
9193
)
9294
head.data = head.data[self.hot_token_id]
93-
else:
94-
self.hot_token_id = None
9595
self.model_runner.model.set_embed_and_head(embed, head)
9696
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
9797

test/srt/test_eagle_infer.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,67 @@ def _test_batch_generation(self, engine):
9595
print("-" * 40)
9696

9797

98+
class TestEAGLEEngineTokenMap(unittest.TestCase):
    """EAGLE speculative decoding with and without a hot-token map must
    reproduce the greedy output of the plain (non-speculative) engine."""

    # Shared engine configuration; individual cases override/extend it.
    BASE_CONFIG = {
        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 8,
        "speculative_num_draft_tokens": 64,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 4,
        "dtype": "float16",
    }

    def setUp(self):
        """Compute the greedy reference output once per test with a plain engine."""
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
        ref_engine.shutdown()

    def test_token_map_accuracy(self):
        """Run each config (baseline and token-map) and compare to the reference."""
        configs = [
            self.BASE_CONFIG,
            {
                **self.BASE_CONFIG,
                "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
            },
        ]

        for config in configs:
            print("testing config: ", config)
            # Label the subtest with the parameter that actually varies so a
            # failure report identifies which config broke (the old label was
            # a constant cuda_graph="enabled" for both configs).
            with self.subTest(
                speculative_token_map=config.get("speculative_token_map")
            ):
                engine = sgl.Engine(**config)
                try:
                    self._test_basic_generation(engine)
                    self._test_batch_generation(engine)
                finally:
                    # Always release the engine, even if an assertion fails.
                    engine.shutdown()

    def _test_basic_generation(self, engine):
        # Greedy decoding (temperature=0) must match the reference exactly.
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        # Smoke-test batched generation; output is printed, not asserted.
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 30}

        outputs = engine.generate(prompts, params)
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)
157+
158+
98159
prompts = [
99160
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
100161
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',

0 commit comments

Comments
 (0)