Skip to content

Commit 6626fa1

Browse files
Correct bs to batch_size
1 parent 2c6b191 commit 6626fa1

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

sharktank/tests/evaluate/perplexity_vmfb_test.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ def test_llama3_8B_f16_decomposed(self):
5454
f"--iree-hip-target={self.iree_hip_target}",
5555
f"--tensor-parallelism-size=1",
5656
f"--attention-kernel=decomposed",
57-
f"--num-prompts={self.bs}",
57+
f"--num-prompts={self.batch_size}",
5858
]
5959
)
6060

6161
baseline_mean_perplexity = round(
62-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
62+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
6363
)
6464
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
6565

@@ -73,7 +73,7 @@ def test_llama3_8B_f16_decomposed(self):
7373
)
7474

7575
@skipif_run_quick_llama_test
76-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
76+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
7777
def test_llama3_8B_f16(self):
7878

7979
# Llama 3.1 8B non-decomposed
@@ -90,12 +90,12 @@ def test_llama3_8B_f16(self):
9090
f"--iree-hip-target={self.iree_hip_target}",
9191
f"--tensor-parallelism-size=1",
9292
f"--attention-kernel=torch_sdpa",
93-
f"--num-prompts={self.bs}",
93+
f"--num-prompts={self.batch_size}",
9494
]
9595
)
9696

9797
baseline_mean_perplexity = round(
98-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
98+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
9999
)
100100
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
101101

@@ -109,7 +109,7 @@ def test_llama3_8B_f16(self):
109109
)
110110

111111
@skipif_run_quick_llama_test
112-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
112+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
113113
def test_llama3_8B_fp8_decomposed(self):
114114

115115
# Llama 3.1 8B decomposed
@@ -126,12 +126,12 @@ def test_llama3_8B_fp8_decomposed(self):
126126
f"--iree-hip-target={self.iree_hip_target}",
127127
f"--tensor-parallelism-size=1",
128128
f"--attention-kernel=decomposed",
129-
f"--num-prompts={self.bs}",
129+
f"--num-prompts={self.batch_size}",
130130
]
131131
)
132132

133133
baseline_mean_perplexity = round(
134-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
134+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
135135
)
136136
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
137137

@@ -145,7 +145,7 @@ def test_llama3_8B_fp8_decomposed(self):
145145
)
146146

147147
@skipif_run_quick_llama_test
148-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
148+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
149149
def test_llama3_8B_fp8(self):
150150

151151
# Llama 3.1 8B non-decomposed
@@ -162,12 +162,12 @@ def test_llama3_8B_fp8(self):
162162
f"--iree-hip-target={self.iree_hip_target}",
163163
f"--tensor-parallelism-size=1",
164164
f"--attention-kernel=torch_sdpa",
165-
f"--num-prompts={self.bs}",
165+
f"--num-prompts={self.batch_size}",
166166
]
167167
)
168168

169169
baseline_mean_perplexity = round(
170-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
170+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
171171
)
172172
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
173173

@@ -200,12 +200,12 @@ def test_llama3_405B_f16_decomposed(self):
200200
f"--iree-hip-target={self.iree_hip_target}",
201201
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
202202
f"--attention-kernel=decomposed",
203-
f"--num-prompts={self.bs}",
203+
f"--num-prompts={self.batch_size}",
204204
]
205205
)
206206

207207
baseline_mean_perplexity = round(
208-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
208+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
209209
)
210210
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
211211

@@ -219,7 +219,7 @@ def test_llama3_405B_f16_decomposed(self):
219219
)
220220

221221
@skipif_run_quick_llama_test
222-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
222+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
223223
def test_llama3_405B_f16(self):
224224

225225
# Llama 3.1 405B non-decomposed
@@ -236,12 +236,12 @@ def test_llama3_405B_f16(self):
236236
f"--iree-hip-target={self.iree_hip_target}",
237237
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
238238
f"--attention-kernel=torch_sdpa",
239-
f"--num-prompts={self.bs}",
239+
f"--num-prompts={self.batch_size}",
240240
]
241241
)
242242

243243
baseline_mean_perplexity = round(
244-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
244+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
245245
)
246246
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
247247

@@ -255,7 +255,7 @@ def test_llama3_405B_f16(self):
255255
)
256256

257257
@skipif_run_quick_llama_test
258-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
258+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
259259
def test_llama3_405B_fp8_decomposed(self):
260260

261261
# Llama 3.1 405B decomposed
@@ -272,12 +272,12 @@ def test_llama3_405B_fp8_decomposed(self):
272272
f"--iree-hip-target={self.iree_hip_target}",
273273
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
274274
f"--attention-kernel=decomposed",
275-
f"--num-prompts={self.bs}",
275+
f"--num-prompts={self.batch_size}",
276276
]
277277
)
278278

279279
baseline_mean_perplexity = round(
280-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
280+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
281281
)
282282
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
283283

@@ -291,7 +291,7 @@ def test_llama3_405B_fp8_decomposed(self):
291291
)
292292

293293
@skipif_run_quick_llama_test
294-
@pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
294+
@pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
295295
def test_llama3_405B_fp8(self):
296296

297297
# Llama 3.1 405B non-decomposed
@@ -308,12 +308,12 @@ def test_llama3_405B_fp8(self):
308308
f"--iree-hip-target={self.iree_hip_target}",
309309
f"--tensor-parallelism-size={self.tensor_parallelism_size}",
310310
f"--attention-kernel=torch_sdpa",
311-
f"--num-prompts={self.bs}",
311+
f"--num-prompts={self.batch_size}",
312312
]
313313
)
314314

315315
baseline_mean_perplexity = round(
316-
np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
316+
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
317317
)
318318
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
319319

0 commit comments

Comments
 (0)