
Commit bcdb20c

Updates
1 parent 6e88306 commit bcdb20c

2 files changed, +20 -8 lines changed


benchmarks/microbenchmarks/test/benchmark_config.yml (+2 -2)

@@ -8,8 +8,8 @@ quantization_config_recipe_names:
   - "float8wo"
 # sparsity_config_recipe_names:
 # Will run a baseline inference for model by default, without sparsity for comparison
-  - "semi-sparse"
-  - "block"
+# - "semi-sparse"
+# - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
   - name: "small_bf16_linear"
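For a quick sanity check of the edited config, the recipe lists can be loaded with PyYAML. This is a minimal sketch, not part of the commit; it assumes the file path shown above and that PyYAML is installed.

# Minimal sketch (assumption): inspect which recipe names remain active
# after the sparsity entries are commented out in benchmark_config.yml.
import yaml

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    config = yaml.safe_load(f)

# Only the quantization recipes should be present now; the sparsity recipe
# list is fully commented out, so its key is absent from the parsed config.
print(config.get("quantization_config_recipe_names"))
print(config.get("sparsity_config_recipe_names"))  # expected: None
print(config["output_dir"])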

torchao/testing/model_architectures.py (+18 -6)

@@ -21,9 +21,7 @@ def forward(self, x):
 
 
 class LNLinearActivationModel(nn.Module):
-    def __init__(
-        self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid", device=None
-    ):
+    def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"):
         super().__init__()
 
         activation = activation.lower()
@@ -41,7 +39,7 @@ def __init__(
             raise ValueError(f"Unsupported activation: {activation}")
 
         self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False)
-        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype, device=device)
+        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype=dtype)
         self.activation = activation_map[activation]
 
     def forward(self, x):
@@ -50,6 +48,20 @@ def forward(self, x):
         return self.activation(x)
 
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
 class TransformerBlock(torch.nn.Module):
     def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         super().__init__()
@@ -72,8 +84,8 @@ def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16):
         )
 
         # Layer norms
-        self.norm1 = nn.RMSNorm(hidden_dim, dtype=dtype)
-        self.norm2 = nn.RMSNorm(hidden_dim, dtype=dtype)
+        self.norm1 = RMSNorm(hidden_dim).to(dtype)
+        self.norm2 = RMSNorm(hidden_dim).to(dtype)
 
         # Activation
         self.activation = torch.nn.GELU()
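For context, the added RMSNorm implements standard root-mean-square normalization and replaces nn.RMSNorm inside TransformerBlock. Below is a minimal sketch that exercises it the way TransformerBlock now constructs its norms. The import path follows this file; the tensor shapes, eps, and tolerances are illustrative assumptions, and the comparison against torch.nn.RMSNorm only runs where that built-in is available (recent PyTorch releases).

# Minimal sketch (assumption: this torchao checkout is importable; shapes,
# eps, and tolerances below are illustrative, not part of the commit).
import torch
import torch.nn as nn

from torchao.testing.model_architectures import RMSNorm

hidden_dim = 64
x = torch.randn(2, 8, hidden_dim, dtype=torch.bfloat16)

# Mirrors how TransformerBlock now builds its norms: construct in fp32,
# then cast the weight to the working dtype.
norm = RMSNorm(hidden_dim).to(torch.bfloat16)
y = norm(x)
print(y.shape, y.dtype)  # torch.Size([2, 8, 64]) torch.bfloat16

# Where the built-in exists, the two should agree closely for the same eps;
# compare in float32 to avoid bfloat16 rounding noise.
if hasattr(nn, "RMSNorm"):
    ref = nn.RMSNorm(hidden_dim, eps=1e-5)
    torch.testing.assert_close(
        RMSNorm(hidden_dim)(x.float()), ref(x.float()), rtol=1e-5, atol=1e-5
    )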

0 commit comments
