@@ -95,6 +95,67 @@ def _test_batch_generation(self, engine):
            print("-" * 40)


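+# With greedy sampling, EAGLE speculative decoding (with or without a
+# frequency-based token map) should reproduce the target model's output exactly.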
+class TestEAGLEEngineTokenMap(unittest.TestCase):
+    BASE_CONFIG = {
+        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
+        "speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
+        "speculative_algorithm": "EAGLE",
+        "speculative_num_steps": 5,
+        "speculative_eagle_topk": 8,
+        "speculative_num_draft_tokens": 64,
+        "mem_fraction_static": 0.7,
+        "cuda_graph_max_bs": 4,
+        "dtype": "float16",
+    }
+
+    def setUp(self):
+        self.prompt = "Today is a sunny day and I like"
+        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
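+        # Reference output from the target model alone (no speculative decoding).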
+        ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
+        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
+        ref_engine.shutdown()
+
+    def test_token_map_accuracy(self):
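+        # Baseline EAGLE config first, then the same config with the FR-Spec token map.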
+        configs = [
+            self.BASE_CONFIG,
+            {
+                **self.BASE_CONFIG,
+                "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
+            },
+        ]
+
+        for config in configs:
+            print("testing config: ", config)
+            with self.subTest(cuda_graph="enabled"):
+                engine = sgl.Engine(**config)
+                try:
+                    self._test_basic_generation(engine)
+                    self._test_batch_generation(engine)
+                finally:
+                    engine.shutdown()
+
+    def _test_basic_generation(self, engine):
+        output = engine.generate(self.prompt, self.sampling_params)["text"]
+        print(f"{output=}, {self.ref_output=}")
+        self.assertEqual(output, self.ref_output)
+
+    def _test_batch_generation(self, engine):
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        params = {"temperature": 0, "max_new_tokens": 30}
+
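+        # Batch generation is smoke-tested only: outputs are printed, not asserted.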
+        outputs = engine.generate(prompts, params)
+        for prompt, output in zip(prompts, outputs):
+            print(f"Prompt: {prompt}")
+            print(f"Generated: {output['text']}")
+            print("-" * 40)
+
+
prompts = [
    "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like[/INST]"
    '[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',