Commit a724a37

Model shapes config (#2036)
1 parent d06b3e3 commit a724a37

12 files changed: +468 -77 lines

benchmarks/microbenchmarks/README.md

+61 -1
@@ -63,14 +63,74 @@ Currently, quantization string is in same format as the one being passed in llam
 
 ### Model Types
 - `linear`: Simple linear layer
-- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
+- `ln_linear_<activation>`: LayerNorm + Linear + Activation, where activation can be:
+  - `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
+  - `ln_linear_relu`: LayerNorm + Linear + ReLU
+  - `ln_linear_leakyrelu`: LayerNorm + Linear + LeakyReLU
+  - `ln_linear_relu6`: LayerNorm + Linear + ReLU6
+  - `ln_linear_gelu`: LayerNorm + Linear + GELU
+  - `ln_linear_silu`: LayerNorm + Linear + SiLU
+  - `ln_linear_hardswish`: LayerNorm + Linear + Hardswish
+- `transformer_block`: Transformer block with self-attention and MLP
 
 ### Device Options
 - `cuda`: NVIDIA GPU
 - `xpu`: Intel GPU
 - `mps`: Apple Silicon GPU
 - `cpu`: CPU fallback
 
+### Shape Generation Options
+- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
+  ```yaml
+  matrix_shapes:
+    - name: "custom"
+      shapes: [
+        [1024, 1024, 1024], # [m, k, n]
+        [2048, 4096, 1024]
+      ]
+  ```
+
+- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
+  - Generates shapes for: "attn.wqkv", "attn.w0", "ffn.w13", "ffn.w2"
+  ```yaml
+  matrix_shapes:
+    - name: "llama"
+  ```
+
+- `pow2`: Generate shapes with dimensions that are powers of 2
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 10, which is 1024)
+    - `max_power`: Maximum power of 2 (default: 14, which is 16,384)
+  ```yaml
+  matrix_shapes:
+    - name: "pow2"
+      min_power: 10 # 2^10 = 1024
+      max_power: 12 # 2^12 = 4096
+  ```
+
+- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 10, which is 1024)
+    - `max_power`: Maximum power of 2 (default: 14, which is 16,384)
+  ```yaml
+  matrix_shapes:
+    - name: "pow2_extended"
+      min_power: 10 # Generates: 1024, 1536, 2048, 3072, etc.
+      max_power: 11
+  ```
+
+- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 8, which is 256)
+    - `max_power`: Maximum power of 2 (default: 15, which is 32,768)
+  - Note: This generates all combinations of M, K, N dimensions, which can be a large number of shapes
+  ```yaml
+  matrix_shapes:
+    - name: "sweep"
+      min_power: 8 # 2^8 = 256
+      max_power: 9 # 2^9 = 512
+  ```
+
 ## Output
 
 Results are saved to a CSV file in the specified output directory
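To make the generator arithmetic concrete, here is a small sketch (not part of the commit) mirroring the logic added to `benchmark_runner.py` below: the `llama` option derives its M dimension from a hard-coded batch size and sequence length, and `pow2_extended` interleaves each power of 2 with the midpoint to the next one.

```python
# M dimension used by the "llama" generator: batch size x sequence length.
bsz, seq_len = 4, 4096
M = bsz * seq_len  # 16384

# Dimension sequence produced by "pow2_extended" for min_power=10, max_power=11.
for p in range(10, 12):
    print(2**p, 2**p + 2 ** (p - 1))  # 1024 1536, then 2048 3072
```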

benchmarks/microbenchmarks/benchmark_inference.py

+4 -2
@@ -22,12 +22,14 @@
     BenchmarkConfig,
     BenchmarkResult,
     clean_caches,
-    create_model_and_input,
     model_inference_time_in_ms,
     string_to_config,
 )
 from torchao.quantization import quantize_
 from torchao.sparsity.sparse_api import sparsify_
+from torchao.testing.model_architectures import (
+    create_model_and_input_data,
+)
 
 
 def run(config: BenchmarkConfig) -> BenchmarkResult:
@@ -38,7 +40,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
     # Create output directory if it doesn't exist
     Path(config.output_dir).mkdir(parents=True, exist_ok=True)
 
-    base_model, input_data = create_model_and_input(
+    base_model, input_data = create_model_and_input_data(
         config.model_type,
         config.m,
         config.k,
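Apart from the rename, the relocated helper is called the same way as before. A minimal usage sketch (not part of the commit), with keyword arguments as used in the updated `test_utils.py` below; the output shape comment is inferred from those tests:

```python
import torch

from torchao.testing.model_architectures import create_model_and_input_data

# Build a toy model plus a matching (m, k) input tensor on the requested device.
model, input_data = create_model_and_input_data(
    model_type="linear",  # or "ln_linear_sigmoid", "transformer_block", ...
    m=16,
    k=64,
    n=32,
    high_precision_dtype=torch.float32,
    device="cpu",
)
out = model(input_data)  # (16, 32), i.e. (m, n), per the tests
```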

benchmarks/microbenchmarks/benchmark_runner.py

+42 -1
@@ -48,9 +48,50 @@ def get_shapes_for_config(
         name = shape_config["name"]
         if name == "custom":
             shapes.extend([(name, shape) for shape in shape_config["shapes"]])
+        elif name == "llama":
+            # LLaMa 2 70B single-node weight shapes
+            # assumes fused attn.wqkv and ffn.w13
+            bsz, seq_len = 4, 4096
+            M = bsz * seq_len
+            llama_shapes = {
+                "attn.wqkv": (M, 8192, 1280),
+                "attn.w0": (M, 1024, 8192),
+                "ffn.w13": (M, 8192, 7168),
+                "ffn.w2": (M, 3584, 8192),
+            }
+            shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()])
+        elif name == "pow2":
+            # Generate shapes with dimensions that are powers of 2
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val = 2**power_of_2
+                shapes.append((f"{name}_{idx}", [val, val, val]))
+        elif name == "pow2_extended":
+            # Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val1 = 2**power_of_2
+                val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+                shapes.append((f"{name}_{idx * 2}", [val1, val1, val1]))
+                shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2]))
+        elif name == "sweep":
+            # Generate a sweep of shapes with different powers of 2 for M, K, N
+            min_p2 = shape_config.get("min_power", 8)  # 256
+            max_p2 = shape_config.get("max_power", 15)  # 32,768
+            counter = 0
+            for M_p2 in range(min_p2, max_p2 + 1):
+                M = 2**M_p2
+                for K_p2 in range(min_p2, max_p2 + 1):
+                    K = 2**K_p2
+                    for N_p2 in range(min_p2, max_p2 + 1):
+                        N = 2**N_p2
+                        shapes.append((f"{name}_{counter}", [M, K, N]))
+                        counter += 1
         else:
             raise NotImplementedError(
-                f"Shape config {name} not supported. Currently only supports custom shapes."
+                f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
             )
     return shapes
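A usage sketch for the extended generator (not part of the commit, but consistent with the unit tests added below): each entry of the input list is expanded independently, and the resulting (name, [M, K, N]) tuples are concatenated in order.

```python
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

shapes = get_shapes_for_config([
    {"name": "llama"},                                   # 4 fused LLaMa 2 70B shapes
    {"name": "pow2", "min_power": 10, "max_power": 12},  # 1024^3, 2048^3, 4096^3
])
assert len(shapes) == 7  # 4 llama + 3 pow2
assert shapes[-1] == ("pow2_2", [4096, 4096, 4096])
```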

benchmarks/microbenchmarks/test/benchmark_config.yml

+45 -0
@@ -26,3 +26,48 @@ model_params:
     device: "cuda"
     model_type: "linear"
     enable_profiler: true # Enable profiling for this model
+
+  - name: "ln_linear_sigmoid_cuda"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024],
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "ln_linear_sigmoid"
+    enable_profiler: true
+
+  - name: "bf16_transformer_block"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
+    enable_profiler: true
+
+  - name: "large_bf16_ln_linear"
+    matrix_shapes:
+      - name: "llama" # Example of using LLaMa shapes
+      - name: "pow2" # Example of using power of 2 shapes
+        min_power: 10 # 1024
+        max_power: 12 # 4096
+      - name: "pow2_extended" # Example of using extended power of 2 shapes
+        min_power: 10 # 1024
+        max_power: 11 # 2048
+      - name: "sweep" # Example of using sweep shapes (note: this generates many shapes)
+        min_power: 8 # 256
+        max_power: 9 # 512
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "linear"
+    enable_profiler: true # Enable profiling for this model
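As a sanity check (not part of the commit), the number of benchmark shapes the `large_bf16_ln_linear` entry expands to can be computed from the generator rules added in `benchmark_runner.py` above:

```python
llama = 4                          # fixed set of fused LLaMa 2 70B weight shapes
pow2 = 12 - 10 + 1                 # one cube per power: 1024, 2048, 4096
pow2_extended = 2 * (11 - 10 + 1)  # two cubes per power: 1024, 1536, 2048, 3072
sweep = (9 - 8 + 1) ** 3           # every (M, K, N) combination of 256 and 512
print(llama + pow2 + pow2_extended + sweep)  # 19 shapes in total
```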

benchmarks/microbenchmarks/test/test_benchmark_profiler.py

+1 -1
@@ -15,8 +15,8 @@
 )
 from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
-    ToyLinearModel,
 )
+from torchao.testing.model_architectures import ToyLinearModel
 
 
 class TestBenchmarkProfiler(unittest.TestCase):

benchmarks/microbenchmarks/test/test_benchmark_runner.py

+60 -0
@@ -57,12 +57,72 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
     def test_get_shapes_for_config(self):
+        # Test custom shapes
         shapes = get_shapes_for_config(
             self.test_config["model_params"][0]["matrix_shapes"]
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
 
+        # Test llama shapes
+        llama_shapes = get_shapes_for_config([{"name": "llama"}])
+        self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
+        self.assertTrue(
+            any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)
+        )
+
+        # Test pow2 shapes
+        pow2_shapes = get_shapes_for_config(
+            [{"name": "pow2", "min_power": 10, "max_power": 12}]
+        )
+        self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
+        self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
+
+        # Test pow2_extended shapes
+        pow2_extended_shapes = get_shapes_for_config(
+            [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
+        )
+        self.assertEqual(
+            len(pow2_extended_shapes), 4
+        )  # 2 powers of 2, each with 2 variants
+        self.assertEqual(
+            pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])
+        )  # 2^10
+        self.assertEqual(
+            pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])
+        )  # 2^10 + 2^9
+        self.assertEqual(
+            pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])
+        )  # 2^11
+        self.assertEqual(
+            pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])
+        )  # 2^11 + 2^10
+
+        # Test sweep shapes (limited to a small range for testing)
+        sweep_shapes = get_shapes_for_config(
+            [{"name": "sweep", "min_power": 8, "max_power": 9}]
+        )
+        # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
+        self.assertEqual(len(sweep_shapes), 8)
+        # Check that all shapes have the expected format
+        for name, shape in sweep_shapes:
+            self.assertTrue(name.startswith("sweep_"))
+            self.assertEqual(len(shape), 3)  # [M, K, N]
+            # Check that all dimensions are powers of 2 between 2^8 and 2^9
+            for dim in shape:
+                self.assertTrue(dim in [256, 512])  # 2^8, 2^9
+
     def test_get_param_combinations(self):
         model_param = self.test_config["model_params"][0]
         shapes, params = get_param_combinations(model_param)

benchmarks/microbenchmarks/test/test_utils.py

+10 -8
@@ -16,15 +16,17 @@
     BlockSparseWeightConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
-    LNLinearSigmoid,
     SemiSparseWeightConfig,
-    ToyLinearModel,
     clean_caches,
-    create_model_and_input,
     generate_results_csv,
     get_default_device,
     string_to_config,
 )
+from torchao.testing.model_architectures import (
+    LNLinearActivationModel,
+    ToyLinearModel,
+    create_model_and_input_data,
+)
 
 
 class TestUtils(unittest.TestCase):
@@ -153,7 +155,7 @@ def test_toy_linear_model(self):
         self.assertEqual(out.dtype, torch.float32)
 
     def test_ln_linear_sigmoid(self):
-        model = LNLinearSigmoid(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        model = LNLinearActivationModel(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
         x = torch.randn(16, 64)
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
@@ -162,9 +164,9 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range
 
-    def test_create_model_and_input(self):
+    def test_create_model_and_input_data(self):
         m, k, n = 16, 64, 32
-        model, input_data = create_model_and_input(
+        model, input_data = create_model_and_input_data(
            model_type="linear",
            m=m,
            k=k,
@@ -175,15 +177,15 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, ToyLinearModel)
         self.assertEqual(input_data.shape, (m, k))
 
-        model, input_data = create_model_and_input(
+        model, input_data = create_model_and_input_data(
             model_type="ln_linear_sigmoid",
             m=m,
             k=k,
             n=n,
             high_precision_dtype=torch.float32,
             device="cpu",
         )
-        self.assertIsInstance(model, LNLinearSigmoid)
+        self.assertIsInstance(model, LNLinearActivationModel)
         self.assertEqual(input_data.shape, (m, k))
 
     def test_generate_results_csv(self):
