@@ -63,7 +63,7 @@ def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     return True


-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -73,7 +73,7 @@ def test_preserves_dtype(self) -> None:
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
             x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype

     def test_differentiable_casts(self) -> None:
         lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -108,7 +108,7 @@ def test_index_put(self):
         fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
         fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             b[index] = fp8_a
             fp8_b[index] = a
             fp8_b_bad[index] = fp8_a
@@ -122,7 +122,7 @@ def test_copy_(self):
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
         torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail

         fp8_b = Float8Tensor(
@@ -149,21 +149,49 @@ def test_weights_only_load(self):
         buffer.seek(0)
         _ = torch.load(buffer, weights_only=True)

-    def test_axiswise_dynamic_cast(self):
-        a = torch.randn(16, 32, dtype=torch.bfloat16)
+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("dim_name", ["first", "last"])
+    def test_axiswise_dynamic_cast(self, shape, dim_name):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+
+        if dim_name == "first":
+            dim = 0
+        elif dim_name == "last":
+            dim = len(a.shape) - 1
+
         linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    # TODO(next) make this work
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16, device="cuda")
+        linear_mm_config = LinearMMConfig()
+
         a_fp8 = hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=0,
         )
-        # print(a_fp8)
-        # print(a_fp8.to_original_precision())
-        # print(a_fp8.t())
-        b = a_fp8.t()
-        # TODO check numerical accuracy
+        # a_fp8._data.shape is (3, 5, 7)
+        # a_fp8._scale.shape is (1, 5, 7)
+        print(a_fp8._scale.shape)
+
+        # reshape to (3, 5 * 7)
+        # a_fp8._scale.shape should be (1, 5 * 7)
+        a_fp8_r = a_fp8.reshape(3, -1)
+        print(a_fp8_r._scale.shape)
+

     def test_axiswise_gemm(self):
         a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
@@ -177,18 +205,21 @@ def test_axiswise_gemm(self):
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
+            axiswise_dim=1,
         )
         b_fp8 = hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
+            axiswise_dim=1,
         )
-        c = torch.mm(a_fp8, b_fp8.t())
-        print(c)
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        print(c_fp8_compute)
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        print('sqnr', sqnr)
         # TODO check numerical accuracy


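As a rough aside on what the new axiswise tests above exercise, here is a minimal plain-PyTorch sketch of per-axis float8 scaling plus an SQNR round-trip check. It assumes only stock PyTorch with float8 dtype support; axiswise_scale and sqnr_db are hypothetical helpers written for this sketch, not torchao APIs, and the scaling convention (scale = max_representable / amax) mirrors the one the tests rely on.

import torch

# Hedged sketch: emulate axiswise float8 scaling without torchao's
# Float8Tensor / hp_tensor_to_float8_dynamic APIs.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # One scale per slice along `dim`, chosen so the slice's max |value|
    # maps to the float8 e4m3 max representable value.
    amax = x.abs().amax(dim=dim, keepdim=True).float()
    return E4M3_MAX / amax.clamp(min=1e-12)

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB, analogous in spirit to the
    # compute_error helper used by the tests above.
    return 20 * torch.log10(ref.norm() / (ref - approx).norm())

a = torch.randn(16, 32, dtype=torch.bfloat16)
s = axiswise_scale(a, dim=1)                      # shape (16, 1): one scale per row
a_fp8 = (a.float() * s).to(torch.float8_e4m3fn)   # quantize
a_dq = (a_fp8.float() / s).to(a.dtype)            # dequantize
print(sqnr_db(a.float(), a_dq.float()))           # typically well above 25 dB

# Reshape bookkeeping mirrored from test_axiswise_reshape: if a (3, 5, 7)
# tensor is scaled along dim 0, its scale has shape (1, 5, 7); reshaping the
# data to (3, 35) then requires reshaping the scale to (1, 35) as well.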