
Commit 3553732

[ghstack] Add support for more shapes
ghstack-source-id: bf520eeb30d3f7c86007459678d96c0dc84a3e21
ghstack-comment-id: 2779402838
Pull Request resolved: #2021

1 parent: 8b0184a

5 files changed: +119 −1 lines changed


benchmarks/microbenchmarks/README.md (+17)

```diff
@@ -46,6 +46,16 @@ model_params:
         [2048, 4096, 1024],
         [4096, 4096, 1024]
       ]
+    - name: "llama"
+    - name: "pow2"
+      min_power: 10 # Optional, default is 10 (1024)
+      max_power: 14 # Optional, default is 14 (16,384)
+    - name: "pow2_extended"
+      min_power: 10 # Optional, default is 10 (1024)
+      max_power: 14 # Optional, default is 14 (16,384)
+    - name: "sweep"
+      min_power: 8 # Optional, default is 8 (256)
+      max_power: 15 # Optional, default is 15 (32,768)
   high_precision_dtype: "torch.bfloat16"
   compile: "max-autotune" # Options: "default", "max-autotune", "false"
   device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
@@ -54,6 +64,13 @@ model_params:
 
 ## Configuration Options
 
+### Shape Generation Options
+- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
+- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
+- `pow2`: Generate shapes with dimensions that are powers of 2 (e.g., 1024, 2048, 4096, etc.)
+- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half (e.g., 1024, 1536, 2048, 3072, etc.)
+- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
+
 ### Quantization Methods
 Currently, quantization string is in same format as the one being passed in llama/generate.py.
 - `baseline`: No quantization
```
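To make the `pow2_extended` series concrete, here is a minimal self-contained sketch of how those dimensions are derived. The helper name `pow2_extended_dims` is hypothetical and only mirrors the README's description; it is not part of the benchmark code.

```python
# Illustration only: derive the "power of 2 plus half" series described above.
def pow2_extended_dims(min_power: int = 10, max_power: int = 14) -> list[int]:
    dims = []
    for p in range(min_power, max_power + 1):
        dims.append(2**p)                 # the power of 2 itself, e.g. 1024
        dims.append(2**p + 2 ** (p - 1))  # plus half of it, e.g. 1536
    return dims

print(pow2_extended_dims(10, 11))  # [1024, 1536, 2048, 3072]
```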

benchmarks/microbenchmarks/benchmark_runner.py (+42 −1)

```diff
@@ -48,9 +48,50 @@ def get_shapes_for_config(
         name = shape_config["name"]
         if name == "custom":
             shapes.extend([(name, shape) for shape in shape_config["shapes"]])
+        elif name == "llama":
+            # LLaMa 2 70B single-node weight shapes
+            # assumes fused attn.wqkv and ffn.w13
+            bsz, seq_len = 4, 4096
+            M = bsz * seq_len
+            llama_shapes = {
+                "attn.wqkv": (M, 8192, 1280),
+                "attn.w0": (M, 1024, 8192),
+                "ffn.w13": (M, 8192, 7168),
+                "ffn.w2": (M, 3584, 8192),
+            }
+            shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()])
+        elif name == "pow2":
+            # Generate shapes with dimensions that are powers of 2
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val = 2**power_of_2
+                shapes.append((f"{name}_{idx}", [val, val, val]))
+        elif name == "pow2_extended":
+            # Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val1 = 2**power_of_2
+                val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+                shapes.append((f"{name}_{idx * 2}", [val1, val1, val1]))
+                shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2]))
+        elif name == "sweep":
+            # Generate a sweep of shapes with different powers of 2 for M, K, N
+            min_p2 = shape_config.get("min_power", 8)  # 256
+            max_p2 = shape_config.get("max_power", 15)  # 32,768
+            counter = 0
+            for M_p2 in range(min_p2, max_p2 + 1):
+                M = 2**M_p2
+                for K_p2 in range(min_p2, max_p2 + 1):
+                    K = 2**K_p2
+                    for N_p2 in range(min_p2, max_p2 + 1):
+                        N = 2**N_p2
+                        shapes.append((f"{name}_{counter}", [M, K, N]))
+                        counter += 1
         else:
             raise NotImplementedError(
-                f"Shape config {name} not supported. Currently only supports custom shapes."
+                f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
             )
     return shapes
```
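A hedged usage sketch of the function as added above. The import path is assumed from the file's location in the repo, and the printed values follow from the defaults shown in the diff:

```python
# Exercise get_shapes_for_config as defined above (import path assumed).
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

shapes = get_shapes_for_config(
    [
        {"name": "pow2", "min_power": 10, "max_power": 11},
        {"name": "llama"},
    ]
)
for shape_name, (m, k, n) in shapes:
    print(f"{shape_name}: M={m}, K={k}, N={n}")
# Expected, per the implementation above:
#   pow2_0: M=1024, K=1024, N=1024
#   pow2_1: M=2048, K=2048, N=2048
#   llama_attn.wqkv: M=16384, K=8192, N=1280  (M = bsz 4 * seq_len 4096)
#   ... (attn.w0, ffn.w13, ffn.w2 follow)
```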

benchmarks/microbenchmarks/test/benchmark_config.yml (+14)

```diff
@@ -31,6 +31,20 @@ model_params:
         [2048, 4096, 1024],
         # [4096, 4096, 1024]
       ]
+    # Example of using LLaMa shapes
+    - name: "llama"
+    # Example of using power of 2 shapes
+    - name: "pow2"
+      min_power: 10 # 1024
+      max_power: 12 # 4096
+    # Example of using extended power of 2 shapes
+    - name: "pow2_extended"
+      min_power: 10 # 1024
+      max_power: 11 # 2048
+    # Example of using sweep shapes (commented out as it generates many shapes)
+    # - name: "sweep"
+    #   min_power: 8 # 256
+    #   max_power: 9 # 512
   high_precision_dtype: "torch.bfloat16"
   use_torch_compile: true
   torch_compile_mode: "max-autotune"
```
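A back-of-the-envelope check of why `sweep` is left commented out in the test config: it emits one shape per (M, K, N) combination, so the shape count grows cubically with the power range. The helper below is illustrative only, not part of the repo:

```python
# Count the shapes the "sweep" generator would produce for a power range.
def sweep_count(min_power: int, max_power: int) -> int:
    n_powers = max_power - min_power + 1
    return n_powers**3  # one shape per (M, K, N) combination

print(sweep_count(8, 9))   # 8 shapes for the commented-out example range
print(sweep_count(8, 15))  # 512 shapes with the documented defaults
```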

benchmarks/microbenchmarks/test/test_benchmark_runner.py (+44)

```diff
@@ -57,11 +57,55 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
     def test_get_shapes_for_config(self):
+        # Test custom shapes
         shapes = get_shapes_for_config(
             self.test_config["model_params"][0]["matrix_shapes"]
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
+
+        # Test llama shapes
+        llama_shapes = get_shapes_for_config([
+            {"name": "llama"}
+        ])
+        self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
+        self.assertTrue(any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes))
+        self.assertTrue(any(name.startswith("llama_attn.w0") for name, _ in llama_shapes))
+        self.assertTrue(any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes))
+        self.assertTrue(any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes))
+
+        # Test pow2 shapes
+        pow2_shapes = get_shapes_for_config([
+            {"name": "pow2", "min_power": 10, "max_power": 12}
+        ])
+        self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
+        self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
+
+        # Test pow2_extended shapes
+        pow2_extended_shapes = get_shapes_for_config([
+            {"name": "pow2_extended", "min_power": 10, "max_power": 11}
+        ])
+        self.assertEqual(len(pow2_extended_shapes), 4)  # 2 powers of 2, each with 2 variants
+        self.assertEqual(pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536]))  # 2^10 + 2^9
+        self.assertEqual(pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072]))  # 2^11 + 2^10
+
+        # Test sweep shapes (limited to a small range for testing)
+        sweep_shapes = get_shapes_for_config([
+            {"name": "sweep", "min_power": 8, "max_power": 9}
+        ])
+        # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
+        self.assertEqual(len(sweep_shapes), 8)
+        # Check that all shapes have the expected format
+        for name, shape in sweep_shapes:
+            self.assertTrue(name.startswith("sweep_"))
+            self.assertEqual(len(shape), 3)  # [M, K, N]
+            # Check that all dimensions are powers of 2 between 2^8 and 2^9
+            for dim in shape:
+                self.assertTrue(dim in [256, 512])  # 2^8, 2^9
 
     def test_get_param_combinations(self):
         model_param = self.test_config["model_params"][0]
```
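One invariant the test does not spell out, shown here as a hedged standalone sketch: `pow2_extended` interleaves pure powers of 2 (even indices) with the +half variants (odd indices). The import path is assumed from the file location above:

```python
# Verify the pow2_extended naming interleaves pure powers with +half variants.
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

for name, (m, _, _) in get_shapes_for_config(
    [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
):
    idx = int(name.rsplit("_", 1)[-1])          # trailing index in the name
    is_pure_power = (m & (m - 1)) == 0          # power-of-two bit trick
    assert is_pure_power == (idx % 2 == 0)      # even index <=> pure power of 2
```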

benchmarks/microbenchmarks/utils.py (+2)

```diff
@@ -753,6 +753,7 @@ def print_results(results: List[BenchmarkResult]):
             result.config.name,
             result.config.quantization or "baseline",
             result.config.sparsity or "none",
+            f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})",
             f"{result.model_inference_time_in_ms:.2f}",
             str(result.config.enable_profiler),
             str(result.config.enable_memory_profile),
@@ -774,6 +775,7 @@ def print_results(results: List[BenchmarkResult]):
         "Name",
         "Quantization",
         "Sparsity",
+        "Shape",
         "Inference Time (ms)",
         "Profiler Enabled",
         "Memory Profiling Enabled",
```
