     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     LNLinearSigmoid,
+    RMSNorm,
+    RMSNormLinearActivation,
     SemiSparseWeightConfig,
     ToyLinearModel,
+    TransformerBlock,
     clean_caches,
     create_model_and_input,
     generate_results_csv,
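For context, the new tests below treat `RMSNorm` as a shape-preserving module constructed from `dim` and an optional `eps`. A minimal sketch consistent with that interface (an illustration only, not the benchmark helper itself; the default `eps` is an assumption):

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm: scale inputs by their root-mean-square over the last dim."""

    def __init__(self, dim: int, eps: float = 1e-6):  # default eps is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight  # same shape as the input, e.g. (16, 64) -> (16, 64)
```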
@@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range
 
+    def test_rms_norm(self):
+        # Test RMSNorm
+        rms_norm = RMSNorm(dim=64)
+        x = torch.randn(16, 64)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+        # Test with different eps
+        rms_norm = RMSNorm(dim=64, eps=1e-5)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+    def test_rms_norm_linear_activation(self):
+        # Test with default GELU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with ReLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
+
+        # Test with SiLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+
+        # Test with invalid activation
+        with self.assertRaises(ValueError):
+            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+
+    def test_transformer_block(self):
+        # Test with default parameters
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
+        out = model(x)
+        self.assertEqual(out.shape, (16, 16, 64))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with different parameters
+        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        x = torch.randn(8, 32, 128)
+        out = model(x)
+        self.assertEqual(out.shape, (8, 32, 128))
+
+        # Test with different head dimensions
+        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        x = torch.randn(4, 8, 96)
+        out = model(x)
+        self.assertEqual(out.shape, (4, 8, 96))
+
     def test_create_model_and_input(self):
         m, k, n = 16, 64, 32
         model, input_data = create_model_and_input(
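The assertions in the hunk above pin down the `RMSNormLinearActivation` contract: GELU by default, "relu" and "silu" selectable, a `ValueError` for anything else, and an output of shape (batch, fc_dim2). A hypothetical stand-in matching that contract (the internal structure is an assumption, not the actual utils implementation):

```python
import torch
import torch.nn as nn


class RMSNormLinearActivationSketch(nn.Module):
    """Illustrative RMSNorm -> Linear -> activation stack."""

    def __init__(self, fc_dim1, fc_dim2, dtype=torch.float32, activation="gelu"):
        super().__init__()
        activations = {"gelu": nn.GELU(), "relu": nn.ReLU(), "silu": nn.SiLU()}
        if activation not in activations:
            # Mirrors the ValueError path the test exercises for activation="invalid".
            raise ValueError(f"Unsupported activation: {activation}")
        self.eps = 1e-6  # assumed default
        self.scale = nn.Parameter(torch.ones(fc_dim1, dtype=dtype))
        self.fc = nn.Linear(fc_dim1, fc_dim2, dtype=dtype)
        self.act = activations[activation]

    def forward(self, x):
        # RMS-normalize over the feature dimension, then project and activate.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.act(self.fc(x / rms * self.scale))


# e.g. RMSNormLinearActivationSketch(64, 32)(torch.randn(16, 64)).shape == (16, 32)
```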
@@ -186,6 +244,63 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))
 
+        # Test RMSNormLinearActivation
+        model, input_data = create_model_and_input(
+            model_type="rms_norm_linear_activation",
+            m=m,
+            k=k,
+            n=n,
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, RMSNormLinearActivation)
+        self.assertEqual(input_data.shape, (m, k))
+
+        # Test TransformerBlock
+        model, input_data = create_model_and_input(
+            model_type="transformer_block",
+            m=m,
+            k=k,
+            n=n,  # n is not used for transformer_block
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, TransformerBlock)
+        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+
+    def test_quantization_on_models(self):
+        # Test quantization on RMSNormLinearActivation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 32))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
+        # Test quantization on TransformerBlock
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip quantization test if torchao.quantization.quantize is not available
+            try:
+                from torchao.quantization import quantize
+                quantized_model = quantize(model, config)
+                out = quantized_model(x)
+                self.assertEqual(out.shape, (16, 16, 64))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize not available")
+
     def test_generate_results_csv(self):
         results = [
             BenchmarkResult(
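Both hunks treat `TransformerBlock` as a shape-preserving [batch_size, seq_len, hidden_dim] module parameterized by `hidden_dim`, `num_heads`, and `mlp_ratio` (with `hidden_dim` divisible by `num_heads`, and a seq_len of 16 assumed by `create_model_and_input`). A hypothetical pre-norm sketch consistent with those assertions; the actual block in the benchmark utils may differ:

```python
import torch
import torch.nn as nn


class TransformerBlockSketch(nn.Module):
    """Illustrative pre-norm attention + MLP block; output shape equals input shape."""

    def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.float32):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim, dtype=dtype)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True, dtype=dtype)
        self.norm2 = nn.LayerNorm(hidden_dim, dtype=dtype)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * mlp_ratio, dtype=dtype),
            nn.GELU(),
            nn.Linear(hidden_dim * mlp_ratio, hidden_dim, dtype=dtype),
        )

    def forward(self, x):  # x: [batch_size, seq_len, hidden_dim]
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.norm2(x))


# e.g. TransformerBlockSketch(64)(torch.randn(16, 16, 64)).shape == (16, 16, 64)
```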