Correct bs to batch_size
archana-ramalingam committed Nov 22, 2024
1 parent 2c6b191 commit 6626fa1
Showing 1 changed file with 22 additions and 22 deletions:

sharktank/tests/evaluate/perplexity_vmfb_test.py
@@ -54,12 +54,12 @@ def test_llama3_8B_f16_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=decomposed",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -73,7 +73,7 @@ def test_llama3_8B_f16_decomposed(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_8B_f16(self):

         # Llama 3.1 8B non-decomposed

@@ -90,12 +90,12 @@ def test_llama3_8B_f16(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=torch_sdpa",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -109,7 +109,7 @@ def test_llama3_8B_f16(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_8B_fp8_decomposed(self):

         # Llama 3.1 8B decomposed

@@ -126,12 +126,12 @@ def test_llama3_8B_fp8_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=decomposed",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -145,7 +145,7 @@ def test_llama3_8B_fp8_decomposed(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_8B_fp8(self):

         # Llama 3.1 8B non-decomposed

@@ -162,12 +162,12 @@ def test_llama3_8B_fp8(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
                 f"--attention-kernel=torch_sdpa",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -200,12 +200,12 @@ def test_llama3_405B_f16_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
                 f"--attention-kernel=decomposed",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -219,7 +219,7 @@ def test_llama3_405B_f16_decomposed(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_405B_f16(self):

         # Llama 3.1 405B non-decomposed

@@ -236,12 +236,12 @@ def test_llama3_405B_f16(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
                 f"--attention-kernel=torch_sdpa",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -255,7 +255,7 @@ def test_llama3_405B_f16(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_405B_fp8_decomposed(self):

         # Llama 3.1 405B decomposed

@@ -272,12 +272,12 @@ def test_llama3_405B_fp8_decomposed(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
                 f"--attention-kernel=decomposed",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

@@ -291,7 +291,7 @@ def test_llama3_405B_fp8_decomposed(self):
         )

     @skipif_run_quick_llama_test
-    @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+    @pytest.mark.xfail(reason="Compile Error", raises=IreeCompileException)
     def test_llama3_405B_fp8(self):

         # Llama 3.1 405B non-decomposed

@@ -308,12 +308,12 @@ def test_llama3_405B_fp8(self):
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
                 f"--attention-kernel=torch_sdpa",
-                f"--num-prompts={self.bs}",
+                f"--num-prompts={self.batch_size}",
             ]
         )

         baseline_mean_perplexity = round(
-            np.mean(baseline_perplexity["perplexities"][0 : self.bs]), 6
+            np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
         )
         current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

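Note: besides the bs → batch_size rename, the diff also drops strict=True from each xfail marker. In pytest, a strict xfail turns an unexpected pass into a suite failure (FAILED [XPASS(strict)]), while a non-strict xfail only reports XPASS and lets the run succeed; the raises= argument further narrows the expected failure to that exception type. A minimal standalone sketch of the difference (ToolchainError is a made-up stand-in, not from this repository):

import pytest

class ToolchainError(Exception):
    """Stand-in for a known compile-time failure."""

@pytest.mark.xfail(reason="known compiler bug", raises=ToolchainError)
def test_non_strict():
    # Raising ToolchainError -> reported as XFAIL.
    # Unexpectedly passing -> reported as XPASS; the run still succeeds.
    raise ToolchainError("compile error")

@pytest.mark.xfail(reason="known compiler bug", strict=True, raises=ToolchainError)
def test_strict():
    # Unexpectedly passing here would fail the whole run.
    raise ToolchainError("compile error")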

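For reference, the renamed code averages only the first batch_size baseline perplexities and rounds both means to six decimal places before comparing them against the current run. A self-contained sketch of that pattern with made-up numbers (the tolerance check at the end is assumed, since the visible hunks stop before the test's actual assertion):

import numpy as np

batch_size = 4
baseline_perplexity = {"perplexities": [6.2, 5.9, 6.4, 6.1, 7.3]}  # made-up baseline
current_perplexity = {"mean_perplexity": 6.150001}  # made-up current result

# Only the first batch_size prompts contribute to the baseline mean.
baseline_mean_perplexity = round(
    np.mean(baseline_perplexity["perplexities"][0:batch_size]), 6
)
current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)

# Hypothetical tolerance; the real test's delta lives outside this diff.
delta = 0.5
assert abs(current_mean_perplexity - baseline_mean_perplexity) <= delta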
