Commit d05179c

[ghstack] Add support for more shapes
ghstack-source-id: 31248998d0462276b32501d6f913978e02a1a096
ghstack-comment-id: 2779402838
Pull Request resolved: #2021
1 parent f66dbd8 commit d05179c

File tree: 6 files changed, +251 −80 lines

benchmarks/microbenchmarks/README.md (+17 −0)

@@ -46,6 +46,16 @@ model_params:
       [2048, 4096, 1024],
       [4096, 4096, 1024]
     ]
+    - name: "llama"
+    - name: "pow2"
+      min_power: 10 # Optional, default is 10 (1024)
+      max_power: 14 # Optional, default is 14 (16,384)
+    - name: "pow2_extended"
+      min_power: 10 # Optional, default is 10 (1024)
+      max_power: 14 # Optional, default is 14 (16,384)
+    - name: "sweep"
+      min_power: 8 # Optional, default is 8 (256)
+      max_power: 15 # Optional, default is 15 (32,768)
   high_precision_dtype: "torch.bfloat16"
   compile: "max-autotune" # Options: "default", "max-autotune", "false"
   device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
@@ -54,6 +64,13 @@ model_params:
 
 ## Configuration Options
 
+### Shape Generation Options
+- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
+- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
+- `pow2`: Generate shapes with dimensions that are powers of 2 (e.g., 1024, 2048, 4096, etc.)
+- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half (e.g., 1024, 1536, 2048, 3072, etc.)
+- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
+
 ### Quantization Methods
 Currently, quantization string is in same format as the one being passed in llama/generate.py.
 - `baseline`: No quantization
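To make the new options concrete, here is a minimal sketch of expanding one generator into explicit matmul shapes. The import path is assumed from this commit's file layout and is not part of any documented API.

```python
# Hedged sketch: import path assumed from the repo layout shown in this commit.
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

# "pow2_extended" over powers 10..11 should expand to 1024, 1536, 2048, 3072.
shapes = get_shapes_for_config(
    [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
)
for shape_name, (m, k, n) in shapes:
    print(f"{shape_name}: M={m}, K={k}, N={n}")
```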

benchmarks/microbenchmarks/benchmark_runner.py (+42 −1)

@@ -48,9 +48,50 @@ def get_shapes_for_config(
     name = shape_config["name"]
     if name == "custom":
         shapes.extend([(name, shape) for shape in shape_config["shapes"]])
+    elif name == "llama":
+        # LLaMa 2 70B single-node weight shapes
+        # assumes fused attn.wqkv and ffn.w13
+        bsz, seq_len = 4, 4096
+        M = bsz * seq_len
+        llama_shapes = {
+            "attn.wqkv": (M, 8192, 1280),
+            "attn.w0": (M, 1024, 8192),
+            "ffn.w13": (M, 8192, 7168),
+            "ffn.w2": (M, 3584, 8192),
+        }
+        shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()])
+    elif name == "pow2":
+        # Generate shapes with dimensions that are powers of 2
+        min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+        max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val = 2**power_of_2
+            shapes.append((f"{name}_{idx}", [val, val, val]))
+    elif name == "pow2_extended":
+        # Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+        min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+        max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            shapes.append((f"{name}_{idx*2}", [val1, val1, val1]))
+            shapes.append((f"{name}_{idx*2+1}", [val2, val2, val2]))
+    elif name == "sweep":
+        # Generate a sweep of shapes with different powers of 2 for M, K, N
+        min_p2 = shape_config.get("min_power", 8)  # 256
+        max_p2 = shape_config.get("max_power", 15)  # 32,768
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2**M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2**K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2**N_p2
+                    shapes.append((f"{name}_{counter}", [M, K, N]))
+                    counter += 1
     else:
         raise NotImplementedError(
-            f"Shape config {name} not supported. Currently only supports custom shapes."
+            f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
         )
     return shapes
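One practical note on `sweep`: it emits one shape per (M, K, N) power combination, so the shape count grows cubically with the power range. Below is a small standalone sketch of that counting arithmetic (not the runner's code), useful for estimating how many matmuls a given config would benchmark.

```python
# Standalone estimate of how many shapes a "sweep" config generates.
# Mirrors the nested loops above: one shape per (M, K, N) power combination.
def num_sweep_shapes(min_power: int = 8, max_power: int = 15) -> int:
    n_powers = max_power - min_power + 1
    return n_powers**3

print(num_sweep_shapes(8, 9))   # 8 shapes, as exercised in the unit test below
print(num_sweep_shapes(8, 15))  # 512 shapes with the defaults
```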

benchmarks/microbenchmarks/test/benchmark_config.yml (+55 −39)

@@ -2,27 +2,29 @@
 benchmark_mode: "inference"
 quantization_config_recipe_names:
   # Will run a baseline inference for model by default, without quantization for comparison
-  # - "int4wo-32"
+  - "int4wo-32"
   # - "marlin"
   - "int8wo"
+  - "int8dq"
+  - "float8dq"
 # sparsity_config_recipe_names:
   # Will run a baseline inference for model by default, without sparsity for comparison
   # - "semi-sparse"
   # - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
-  # - name: "small_bf16_linear"
-  #   matrix_shapes:
-  #     - name: "custom"
-  #       shapes: [
-  #         [1024, 1024, 1024], # [m, k, n]
-  #       ]
-  #   high_precision_dtype: "torch.bfloat16"
-  #   use_torch_compile: true
-  #   torch_compile_mode: "max-autotune"
-  #   device: "cuda"
-  #   model_type: "linear"
-  #   enable_profiler: true # Enable profiling for this model
+  - name: "small_bf16_linear"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [1024, 1024, 1024], # [m, k, n]
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "linear"
+    enable_profiler: true # Enable profiling for this model
 
   - name: "large_bf16_ln_linear"
     matrix_shapes:
@@ -31,6 +33,20 @@ model_params:
       [2048, 4096, 1024],
       # [4096, 4096, 1024]
       ]
+      # Example of using LLaMa shapes
+      - name: "llama"
+      # Example of using power of 2 shapes
+      - name: "pow2"
+        min_power: 10 # 1024
+        max_power: 12 # 4096
+      # Example of using extended power of 2 shapes
+      - name: "pow2_extended"
+        min_power: 10 # 1024
+        max_power: 11 # 2048
+      # Example of using sweep shapes (commented out as it generates many shapes)
+      # - name: "sweep"
+      #   min_power: 8 # 256
+      #   max_power: 9 # 512
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
@@ -51,30 +67,30 @@ model_params:
   #   model_type: "linear"
   #   enable_profiler: true # Enable profiling for this model
 
-  - name: "bf16_rms_norm_linear_activation"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [2048, 4096, 1024],
-        ]
-    high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
-    torch_compile_mode: "max-autotune"
-    device: "cuda"
-    model_type: "rms_norm_linear_activation"
-    enable_profiler: true
-    enable_memory_profile: true
+  # - name: "bf16_rms_norm_linear_activation"
+  #   matrix_shapes:
+  #     - name: "custom"
+  #       shapes: [
+  #         [2048, 4096, 1024],
+  #       ]
+  #   high_precision_dtype: "torch.bfloat16"
+  #   use_torch_compile: true
+  #   torch_compile_mode: "max-autotune"
+  #   device: "cuda"
+  #   model_type: "rms_norm_linear_activation"
+  #   enable_profiler: true
+  #   enable_memory_profile: true
 
-  - name: "bf16_transformer_block"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
-        ]
-    high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
-    torch_compile_mode: "max-autotune"
-    device: "cuda"
-    model_type: "transformer_block"
-    enable_profiler: true
-    enable_memory_profile: true
+  # - name: "bf16_transformer_block"
+  #   matrix_shapes:
+  #     - name: "custom"
+  #       shapes: [
+  #         [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
+  #       ]
+  #   high_precision_dtype: "torch.bfloat16"
+  #   use_torch_compile: true
+  #   torch_compile_mode: "max-autotune"
+  #   device: "cuda"
+  #   model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
+  #   enable_profiler: true
+  #   enable_memory_profile: true
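For reference, a hedged sketch of how a config like the one above can be expanded into per-model shape lists before benchmarking; it assumes PyYAML is installed and that `get_shapes_for_config` is importable from the path shown in this commit.

```python
# Hedged sketch: expand each model's matrix_shapes from the YAML config above.
import yaml  # assumes PyYAML is available

from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    config = yaml.safe_load(f)

for model in config["model_params"]:
    shapes = get_shapes_for_config(model["matrix_shapes"])
    print(f"{model['name']}: {len(shapes)} shapes to benchmark")
```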

benchmarks/microbenchmarks/test/test_benchmark_runner.py (+60 −0)

@@ -57,12 +57,72 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
     def test_get_shapes_for_config(self):
+        # Test custom shapes
         shapes = get_shapes_for_config(
             self.test_config["model_params"][0]["matrix_shapes"]
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
 
+        # Test llama shapes
+        llama_shapes = get_shapes_for_config([{"name": "llama"}])
+        self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
+        self.assertTrue(
+            any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)
+        )
+
+        # Test pow2 shapes
+        pow2_shapes = get_shapes_for_config(
+            [{"name": "pow2", "min_power": 10, "max_power": 12}]
+        )
+        self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
+        self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
+
+        # Test pow2_extended shapes
+        pow2_extended_shapes = get_shapes_for_config(
+            [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
+        )
+        self.assertEqual(
+            len(pow2_extended_shapes), 4
+        )  # 2 powers of 2, each with 2 variants
+        self.assertEqual(
+            pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])
+        )  # 2^10
+        self.assertEqual(
+            pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])
+        )  # 2^10 + 2^9
+        self.assertEqual(
+            pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])
+        )  # 2^11
+        self.assertEqual(
+            pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])
+        )  # 2^11 + 2^10
+
+        # Test sweep shapes (limited to a small range for testing)
+        sweep_shapes = get_shapes_for_config(
+            [{"name": "sweep", "min_power": 8, "max_power": 9}]
+        )
+        # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
+        self.assertEqual(len(sweep_shapes), 8)
+        # Check that all shapes have the expected format
+        for name, shape in sweep_shapes:
+            self.assertTrue(name.startswith("sweep_"))
+            self.assertEqual(len(shape), 3)  # [M, K, N]
+            # Check that all dimensions are powers of 2 between 2^8 and 2^9
+            for dim in shape:
+                self.assertTrue(dim in [256, 512])  # 2^8, 2^9
+
     def test_get_param_combinations(self):
         model_param = self.test_config["model_params"][0]
         shapes, params = get_param_combinations(model_param)
