     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingGranularity,
+    ScalingType,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
     return True


-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
             x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype

     def test_differentiable_casts(self) -> None:
         lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
         fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
         fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             b[index] = fp8_a
             fp8_b[index] = a
             fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
         torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail

         fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
         fp8_b.copy_(fp8_a)
         torch.testing.assert_close(fp8_a._data, fp8_b._data)

+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("axiswise_dim", [0, -1])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+
+        # if we scale across dim0, we can only reshape to [3, -1]
+        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        assert list(a_fp8_d0._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d0._scale.shape) == [1, 5, 7]
+
+        a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
+        assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
+        assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d0.to_original_precision(),
+            a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
+
+        # if we scale across dim2, we can only reshape to [-1, 7]
+        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        assert list(a_fp8_d2._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d2._scale.shape) == [3, 5, 1]
+
+        a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
+        assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
+        assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d2.to_original_precision(),
+            a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
+
+    @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    def test_axiswise_gemm(self, a_shape):
+        a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+
+        linear_mm_config = LinearMMConfig()
+
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+        b_fp8 = hp_tensor_to_float8_dynamic(
+            b,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,  # will be transposed
+        )
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        a = a.reshape(-1, a_shape[-1])
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        assert sqnr >= 25.0


 class TestFloat8Linear:
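For readers skimming the diff, here is a minimal standalone sketch of the axiswise dynamic cast that the new tests exercise. It reuses the call pattern from test_axiswise_dynamic_cast above; the LinearMMConfig import path and the expected scale shape noted in the comments are assumptions, not taken from this commit.

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig  # assumed import path

# Cast a bf16 tensor to float8 with one scale per slice along the last dim,
# mirroring the call in test_axiswise_dynamic_cast.
a = torch.randn(16, 32, dtype=torch.bfloat16)
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
print(a_fp8._scale.shape)  # expected: one scale per row, i.e. [16, 1]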