
Commit 8b0184a

More models
ghstack-source-id: e2b23740e25f03ac769cc98182a5522f9f0399bc
ghstack-comment-id: 2779402762
Pull Request resolved: #2020
1 parent: 1134cd6

3 files changed: +253 −0 lines changed

benchmarks/microbenchmarks/test/benchmark_config.yml (+28)
@@ -50,3 +50,31 @@ model_params:
 # device: "cpu"
 # model_type: "linear"
 # enable_profiler: true # Enable profiling for this model
+
+  - name: "bf16_rms_norm_linear_activation"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024],
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "rms_norm_linear_activation"
+    enable_profiler: true
+    enable_memory_profile: true
+
+  - name: "bf16_transformer_block"
+    matrix_shapes:
+      - name: "custom"
+        shapes: [
+          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
+        ]
+    high_precision_dtype: "torch.bfloat16"
+    use_torch_compile: true
+    torch_compile_mode: "max-autotune"
+    device: "cuda"
+    model_type: "transformer_block"
+    enable_profiler: true
+    enable_memory_profile: true
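
For orientation, here is a minimal sketch (not part of the commit) of what one shape row expands to; the [m, k, n] reading and the hard-coded sequence length of 16 follow create_model_and_input in utils.py further down:

    import torch

    m, k, n = 2048, 4096, 1024  # one row of "shapes" above

    # rms_norm_linear_activation: RMSNorm(k) -> Linear(k, n) -> activation, on an [m, k] input
    x_rms = torch.randn(m, k, dtype=torch.bfloat16)     # input shape [2048, 4096]
    # model output shape: [m, n] = [2048, 1024]

    # transformer_block: only k is used (as the hidden dimension); n is ignored
    x_tb = torch.randn(m, 16, k, dtype=torch.bfloat16)  # input shape [2048, 16, 4096]
    # model output keeps the input shape: [m, 16, k]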

benchmarks/microbenchmarks/test/test_utils.py (+115)
@@ -17,8 +17,11 @@
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     LNLinearSigmoid,
+    RMSNorm,
+    RMSNormLinearActivation,
     SemiSparseWeightConfig,
     ToyLinearModel,
+    TransformerBlock,
     clean_caches,
     create_model_and_input,
     generate_results_csv,
@@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range

+    def test_rms_norm(self):
+        # Test RMSNorm
+        rms_norm = RMSNorm(dim=64)
+        x = torch.randn(16, 64)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+        # Test with different eps
+        rms_norm = RMSNorm(dim=64, eps=1e-5)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+    def test_rms_norm_linear_activation(self):
+        # Test with default GELU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with ReLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
+
+        # Test with SiLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+
+        # Test with invalid activation
+        with self.assertRaises(ValueError):
+            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+
+    def test_transformer_block(self):
+        # Test with default parameters
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
+        out = model(x)
+        self.assertEqual(out.shape, (16, 16, 64))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with different parameters
+        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        x = torch.randn(8, 32, 128)
+        out = model(x)
+        self.assertEqual(out.shape, (8, 32, 128))
+
+        # Test with different head dimensions
+        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        x = torch.randn(4, 8, 96)
+        out = model(x)
+        self.assertEqual(out.shape, (4, 8, 96))
+
     def test_create_model_and_input(self):
         m, k, n = 16, 64, 32
         model, input_data = create_model_and_input(
@@ -186,6 +244,63 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))

+        # Test RMSNormLinearActivation
+        model, input_data = create_model_and_input(
+            model_type="rms_norm_linear_activation",
+            m=m,
+            k=k,
+            n=n,
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, RMSNormLinearActivation)
+        self.assertEqual(input_data.shape, (m, k))
+
+        # Test TransformerBlock
+        model, input_data = create_model_and_input(
+            model_type="transformer_block",
+            m=m,
+            k=k,
+            n=n,  # n is not used for transformer_block
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, TransformerBlock)
+        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+
+    def test_quantization_on_models(self):
+        # Test quantization on RMSNormLinearActivation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 32))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
+        # Test quantization on TransformerBlock
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 16, 64))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
     def test_generate_results_csv(self):
         results = [
             BenchmarkResult(
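
A side note, not part of the diff: the guarded import above exists because the returning-style quantize entry point is not present in every torchao build. As a rough sketch only — it assumes a torchao version that exposes the in-place quantize_ API together with an Int8WeightOnlyConfig class, and the import path for the benchmark model is likewise an assumption — the same check could be written against the in-place API:

    import torch
    from torchao.quantization import Int8WeightOnlyConfig, quantize_  # availability assumed
    from benchmarks.microbenchmarks.utils import RMSNormLinearActivation  # import path assumed

    model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
    x = torch.randn(16, 64)

    quantize_(model, Int8WeightOnlyConfig())  # mutates the model in place
    out = model(x)
    assert out.shape == (16, 32)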

benchmarks/microbenchmarks/utils.py (+110)
@@ -383,6 +383,108 @@ def forward(self, x):
         return x


+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+
+    def forward(self, x):
+        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        return x * norm * self.weight
+
+
+class RMSNormLinearActivation(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="gelu"):
+        super().__init__()
+        self.rms_norm = RMSNorm(fc_dim1, dtype=dtype)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype)
+
+        if activation == "gelu":
+            self.activation = torch.nn.GELU()
+        elif activation == "relu":
+            self.activation = torch.nn.ReLU()
+        elif activation == "silu":
+            self.activation = torch.nn.SiLU()
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+
+    def forward(self, x):
+        x = self.rms_norm(x)
+        x = self.fc(x)
+        x = self.activation(x)
+        return x
+
+
+class TransformerBlock(torch.nn.Module):
+    def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+
+        # Self-attention
+        self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype)
+        self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype)
+
+        # MLP
+        self.mlp_ratio = mlp_ratio
+        self.mlp_hidden_dim = int(hidden_dim * mlp_ratio)
+        self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype)
+        self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype)
+
+        # Layer norms
+        self.norm1 = RMSNorm(hidden_dim, dtype=dtype)
+        self.norm2 = RMSNorm(hidden_dim, dtype=dtype)
+
+        # Activation
+        self.activation = torch.nn.GELU()
+
+    def forward(self, x):
+        batch_size, seq_len, _ = x.shape
+
+        # Self-attention
+        residual = x
+        x = self.norm1(x)
+
+        # Reshape qkv projection for better memory layout
+        qkv = self.qkv(x)  # [batch_size, seq_len, 3 * hidden_dim]
+        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch_size, num_heads, seq_len, head_dim]
+        q, k, v = qkv  # Each has shape [batch_size, num_heads, seq_len, head_dim]
+
+        # Scaled dot-product attention with proper reshaping
+        # Reshape for better memory layout and avoid broadcasting issues
+        q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+        k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+        v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim)
+
+        # Compute attention scores
+        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5))
+        attn = torch.softmax(attn, dim=-1)
+
+        # Apply attention to values
+        x = attn @ v  # [batch_size * num_heads, seq_len, head_dim]
+
+        # Reshape back to original dimensions
+        x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
+        x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim)
+
+        # Project back to hidden dimension
+        x = self.proj(x)
+        x = residual + x
+
+        # MLP
+        residual = x
+        x = self.norm2(x)
+        x = self.mlp_fc1(x)
+        x = self.activation(x)
+        x = self.mlp_fc2(x)
+        x = residual + x
+
+        return x
+
+
 def string_to_config(
     quantization: Optional[str], sparsity: Optional[str], **kwargs
 ) -> AOBaseConfig:
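
For readers comparing these hand-rolled modules with PyTorch's built-ins, here is a small sanity check (not part of the commit; it assumes PyTorch >= 2.4 for torch.nn.functional.rms_norm and >= 2.0 for scaled_dot_product_attention) showing that the RMSNorm formula and the explicit softmax attention above match the standard formulations:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # RMSNorm as written above: x * rsqrt(mean(x^2) + eps) * weight
    x = torch.randn(16, 64)
    weight, eps = torch.ones(64), 1e-6
    ours = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    ref = F.rms_norm(x, [64], weight, eps)
    print(torch.allclose(ours, ref, atol=1e-6))  # expected: True

    # Attention as written above: softmax(q @ k^T / sqrt(head_dim)) @ v on flattened heads
    bh, s, d = 16, 32, 8  # (batch_size * num_heads, seq_len, head_dim)
    q, k, v = (torch.randn(bh, s, d) for _ in range(3))
    scores = torch.softmax((q @ k.transpose(-2, -1)) / d**0.5, dim=-1)
    ours_attn = scores @ v
    ref_attn = F.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(ours_attn, ref_attn, atol=1e-5))  # expected: True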
@@ -576,6 +678,14 @@ def create_model_and_input(
     elif model_type == "ln_linear_sigmoid":
         model = LNLinearSigmoid(k, n, high_precision_dtype).to(device)
         input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    elif model_type == "rms_norm_linear_activation":
+        model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device)
+        input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
+    elif model_type == "transformer_block":
+        # For transformer block, k is the hidden dimension
+        model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device)
+        # Input shape for transformer is [batch_size, seq_len, hidden_dim]
+        input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype)
     else:
         raise ValueError(f"Unknown model type: {model_type}")
     return model, input_data
