From 58f258d981812599e20faeecc8a2faf05fdd3960 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Mon, 16 Dec 2024 12:04:43 -0800 Subject: [PATCH] Update llama tests for block size 32 (#696) The block_seq_stride default is changing to 32 instead of 16, so this PR updates the tests to use the block_seq_stride flag and the new numpy inputs for block size 32 to benchmark correctly. This PR also removes the decomposed fp16 tests that are not needed anymore. --------- Signed-off-by: aviator19941 --- sharktank/sharktank/utils/export_artifacts.py | 5 +- .../models/llama/benchmark_amdgpu_test.py | 215 +++--------------- 2 files changed, 32 insertions(+), 188 deletions(-) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index ec75d597e..0bf252525 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -92,6 +92,7 @@ def __init__( iree_hal_target_backends: str, attention_kernel: str, tensor_parallelism_size: int, + block_seq_stride: Optional[int] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent @@ -102,6 +103,7 @@ def __init__( self.iree_hal_target_backends = iree_hal_target_backends self.attention_kernel = attention_kernel self.tensor_parallelism_size = tensor_parallelism_size + self.block_seq_stride = block_seq_stride def timeit(func): def wrapper(*args, **kwargs): @@ -184,6 +186,8 @@ def export_to_mlir( if self.attention_kernel in ["decomposed", "torch"]: export_args.append("--attention-kernel") export_args.append(self.attention_kernel) + if self.block_seq_stride: + export_args.append(f"--block-seq-stride={self.block_seq_stride}") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) @@ -280,7 +284,6 @@ def iree_benchmark_vmfb( benchmark_args += [ "iree-benchmark-module", "--hip_use_streams=true", - "--hip_allow_inline_execution=true", "--device_allocator=caching", f"--module={vmfb_name}", ] diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 13a2c35e4..0c45bdffa 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -74,14 +74,6 @@ def setUp(self): self.dir_path_8b = self.dir_path / "llama-8b" self.temp_dir_8b = Path(self.dir_path_8b) self.temp_dir_8b.mkdir(parents=True, exist_ok=True) - self.llama8b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) self.llama8b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -89,6 +81,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) self.llama8b_fp8_decomposed_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -97,6 +90,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -105,48 +99,42 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=32, ) - 
self.prefill_args_f16 = self.artifacts_dir / "prefill_args" - self.prefill_args_bs4_128_in_tokens_f16 = ( - self.artifacts_dir / "prefill_args_bs4_128" + self.prefill_args_bs4_128_in_tokens_stride_32_f16 = ( + self.artifacts_dir / "prefill_args_bs4_128_stride_32" ) self.prefill_args_bs4_2048_in_tokens_f16 = ( self.artifacts_dir / "prefill_args_bs4_2048" ) - self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.decode_args_bs4_128_in_tokens_stride_32_f16 = ( + self.artifacts_dir / "decode_args_bs4_128_stride_32" + ) self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" - self.iree_run_prefill_args = [ - "--function=prefill_bs4", - f"--input=@{self.prefill_args_f16}/tokens.npy", - f"--input=@{self.prefill_args_f16}/seq_lens.npy", - f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", - f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", - "--benchmark_repetitions=3", - ] self.iree_run_prefill_nondecomposed_args_fp16 = [ "--function=prefill_bs4", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy", - f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/tokens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_stride_32_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] self.iree_run_prefill_nondecomposed_args_fp16_2048 = [ "--function=prefill_bs4", - f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens_2048.npy", + f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/tokens.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_lens.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/seq_block_ids.npy", f"--input=@{self.prefill_args_bs4_2048_in_tokens_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] - self.iree_run_decode_args = [ + self.iree_run_decode_nondecomposed_args_f16 = [ "--function=decode_bs4", - f"--input=@{self.decode_args_f16}/tokens.npy", - f"--input=@{self.decode_args_f16}/seq_lens.npy", - f"--input=@{self.decode_args_f16}/start_positions.npy", - f"--input=@{self.decode_args_f16}/seq_block_ids.npy", - f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/next_tokens.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/seq_lens.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/start_positions.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_bs4_128_in_tokens_stride_32_f16}/cs_f16.npy", "--benchmark_repetitions=3", ] self.iree_run_prefill_args_fp8 = [ @@ -167,46 +155,6 @@ def setUp(self): "--benchmark_repetitions=3", ] - def testBenchmark8B_f16_Decomposed(self): - output_file_name = self.dir_path_8b / "f16_decomposed" - output_mlir = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - export_return_code = 
self.llama8b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama8b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - - @skipif_run_quick_llama_test def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self): output_file_name = self.dir_path_8b / "f16_torch_prefill_128" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -218,7 +166,6 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_128(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -252,7 +199,7 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + self.llama8b_f16_torch_sdpa_artifacts.block_seq_stride = 16 export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -275,7 +222,6 @@ def testBenchmark8B_f16_Non_Decomposed_Prefill_Input_Len_2048(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark8B_f16_Non_Decomposed(self): output_file_name = self.dir_path_8b / "f16_torch" output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( @@ -287,7 +233,6 @@ def testBenchmark8B_f16_Non_Decomposed(self): output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( suffix=".vmfb", prefix=output_file_name ) - self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( mlir_path=output_mlir, json_path=output_json, @@ -304,7 +249,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): hip_device_id=self.iree_device, vmfb_name=output_vmfb, irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, + args=self.iree_run_prefill_nondecomposed_args_fp16, cwd=self.repo_root, ) # benchmark decode @@ -312,7 +257,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): hip_device_id=self.iree_device, vmfb_name=output_vmfb, irpa_path=self.irpa_path, - args=self.iree_run_decode_args, + args=self.iree_run_decode_nondecomposed_args_f16, cwd=self.repo_root, ) @@ -410,14 +355,6 @@ def setUp(self): self.dir_path_70b = self.dir_path / "llama-70b" self.temp_dir_70b = Path(self.dir_path_70b) self.temp_dir_70b.mkdir(parents=True, exist_ok=True) - self.llama70b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) 
self.llama70b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -425,6 +362,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama70b_fp8_decomposed_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -433,6 +371,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama70b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -441,6 +380,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.prefill_args_f16 = self.artifacts_dir / "prefill_args" self.prefill_args_bs4_128_in_tokens_f16 = ( @@ -495,52 +435,6 @@ def setUp(self): @pytest.mark.xfail( reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException ) - def testBenchmark70B_f16_TP8_Decomposed(self): - output_file_name = self.dir_path_70b / "f16_decomposed" - output_mlir = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama70b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - output_shard_file_name = ( - self.artifacts_dir - / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" - ) - if output_shard_file_name.exists(): - self.irpa_path = output_shard_file_name - export_return_code = self.llama70b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama70b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - - @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) def testBenchmark70B_f16_TP8_Non_Decomposed(self): output_file_name = self.dir_path_70b / "f16_torch" output_mlir = self.llama70b_f16_torch_sdpa_artifacts.create_file( @@ -697,14 +591,6 @@ def setUp(self): self.dir_path_405b = self.dir_path / "llama-405b" self.temp_dir_405b = Path(self.dir_path_405b) self.temp_dir_405b.mkdir(parents=True, exist_ok=True) - self.llama405b_f16_decomposed_artifacts = ExportArtifacts( - irpa_path=str(self.irpa_path), - batch_size=4, - iree_hip_target="gfx942", - iree_hal_target_backends="rocm", - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - ) self.llama405b_f16_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path), batch_size=4, @@ -712,6 +598,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama405b_fp8_decomposed_artifacts = 
ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -720,6 +607,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="decomposed", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.llama405b_fp8_torch_sdpa_artifacts = ExportArtifacts( irpa_path=str(self.irpa_path_fp8), @@ -728,6 +616,7 @@ def setUp(self): iree_hal_target_backends="rocm", attention_kernel="torch", tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=16, ) self.prefill_args_f16 = self.artifacts_dir / "prefill_args" self.prefill_args_bs4_128_in_tokens_f16 = ( @@ -779,54 +668,6 @@ def setUp(self): "--benchmark_repetitions=3", ] - @pytest.mark.xfail( - reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException - ) - def testBenchmark405B_f16_TP8_Decomposed(self): - output_file_name = self.dir_path_405b / "f16_decomposed" - output_mlir = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".mlir", prefix=output_file_name - ) - output_json = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".json", prefix=output_file_name - ) - output_vmfb = self.llama405b_f16_decomposed_artifacts.create_file( - suffix=".vmfb", prefix=output_file_name - ) - output_shard_file_name = ( - self.artifacts_dir - / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" - ) - if output_shard_file_name.exists(): - self.irpa_path = output_shard_file_name - export_return_code = self.llama405b_f16_decomposed_artifacts.export_to_mlir( - mlir_path=output_mlir, - json_path=output_json, - ) - self.llama405b_f16_decomposed_artifacts.compile_to_vmfb( - mlir_path=str(output_mlir), - vmfb_path=output_vmfb, - hal_dump_path=output_file_name, - cwd=self.repo_root, - args=self.compile_args, - ) - # benchmark prefill - self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_prefill_args, - cwd=self.repo_root, - ) - # benchmark decode - self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( - hip_device_id=self.iree_device, - vmfb_name=output_vmfb, - irpa_path=self.irpa_path, - args=self.iree_run_decode_args, - cwd=self.repo_root, - ) - @pytest.mark.xfail( reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException )
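
For context on how these updated tests drive the new flag end to end, the rough flow is: construct ExportArtifacts with block_seq_stride=32 (which export_to_mlir() forwards as --block-seq-stride=32), export and compile, then benchmark against the stride-32 numpy inputs. A minimal Python sketch follows; the constructor keywords, method names, and input file names are taken from this patch, while the concrete paths, device id, and tensor_parallelism_size value are illustrative assumptions only.

    from pathlib import Path

    from sharktank.utils.export_artifacts import ExportArtifacts

    # Placeholder locations -- substitute the real artifact/IRPA paths for your setup.
    artifacts_dir = Path("/data/llama3.1/8b")
    irpa_path = artifacts_dir / "llama8b_f16.irpa"
    prefill_inputs = artifacts_dir / "prefill_args_bs4_128_stride_32"
    output_prefix = Path("/tmp/llama-8b/f16_torch")

    # block_seq_stride=32 must match the stride the numpy inputs were generated with;
    # export_to_mlir() appends --block-seq-stride=32 to the export command line.
    artifacts = ExportArtifacts(
        irpa_path=str(irpa_path),
        batch_size=4,
        iree_hip_target="gfx942",
        iree_hal_target_backends="rocm",
        attention_kernel="torch",
        tensor_parallelism_size=1,  # assumption; the tests use their configured value
        block_seq_stride=32,
    )

    output_mlir = artifacts.create_file(suffix=".mlir", prefix=output_prefix)
    output_json = artifacts.create_file(suffix=".json", prefix=output_prefix)
    output_vmfb = artifacts.create_file(suffix=".vmfb", prefix=output_prefix)

    artifacts.export_to_mlir(mlir_path=output_mlir, json_path=output_json)
    artifacts.compile_to_vmfb(
        mlir_path=str(output_mlir),
        vmfb_path=output_vmfb,
        hal_dump_path=output_prefix,
        cwd=str(artifacts_dir),  # the tests pass their repo root here
        args=[],                 # the tests pass their shared compile_args list
    )

    # Benchmark prefill against the stride-32 inputs (decode is analogous, using
    # next_tokens.npy / start_positions.npy from the matching decode_args directory).
    artifacts.iree_benchmark_vmfb(
        hip_device_id="0",       # placeholder HIP device id
        vmfb_name=output_vmfb,
        irpa_path=str(irpa_path),
        args=[
            "--function=prefill_bs4",
            f"--input=@{prefill_inputs}/tokens.npy",
            f"--input=@{prefill_inputs}/seq_lens.npy",
            f"--input=@{prefill_inputs}/seq_block_ids.npy",
            f"--input=@{prefill_inputs}/cs_f16.npy",
            "--benchmark_repetitions=3",
        ],
        cwd=str(artifacts_dir),  # the tests pass their repo root here
    )

Note that the 8B tests in this patch export with block_seq_stride=32 and read from the *_stride_32 input directories, while the 70B and 405B tests explicitly pin block_seq_stride=16 to keep matching their existing numpy inputs.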