 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


+def _assert_tensorwise_scale(aten_op, scale):
+    assert (
+        # TODO(future PR): figure out why tensorwise scaling can have
+        # both rank 0 and rank 1
+        len(scale.shape)
+        in (0, 1)
+    ), f"{aten_op} with axiswise scaling is not supported yet"
+
+
 def implements(aten_ops):
     """Register aten ops to the float8 op table"""
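Note on the helper above: tensorwise scaling carries a scale of rank 0 or 1, while axiswise (per-row) scaling carries a rank-2 scale, which is what the rank check keys off of. A rough shape-only sketch with plain torch tensors (illustrative, not part of this diff):

import torch

x = torch.randn(4, 6)

# tensorwise granularity: one amax (and hence one scale) for the whole tensor -> rank 0
print(x.abs().amax().shape)                     # torch.Size([])
# axiswise (row-wise) granularity: one amax per row -> rank 2
print(x.abs().amax(dim=1, keepdim=True).shape)  # torch.Size([4, 1])

Ops that still assume tensorwise scaling call _assert_tensorwise_scale so that an axiswise-scaled Float8Tensor fails loudly instead of silently reusing a wrongly shaped scale.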
@@ -32,18 +41,16 @@ def decorator(func):

 @implements(
     [
-        aten.view.default,
         aten._unsafe_view.default,
-        aten.t.default,
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
         aten.slice.Tensor,
-        aten.transpose.int,
         aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
         new_data,
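Note: aten.view.default, aten.t.default, and aten.transpose.int are removed from this tensorwise-only desugar list because they get scale-aware handlers in the next hunk; for transposes, the scale must be transposed together with the data so each per-row scale stays attached to the same rows. A shape-only sketch with plain tensors (illustrative, not the Float8Tensor API):

import torch

data = torch.randn(4, 6)
scale = data.abs().amax(dim=1, keepdim=True)  # row-wise scale, shape [4, 1]

# transposing both data and scale keeps the scaled values consistent
assert torch.equal((data * scale).t(), data.t() * scale.t())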
@@ -54,8 +61,61 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )


+@implements(
+    [
+        aten.t.default,
+        aten.transpose.int,
+    ]
+)
+def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+    return Float8Tensor(
+        new_data,
+        new_scale,
+        args[0]._orig_dtype,
+        args[0]._linear_mm_config,
+        args[0]._gemm_input_role,
+    )
+
+
+@implements([aten.view.default])
+def float8_view(aten_op, args, kwargs=None):
+    if len(args[0]._scale.shape) < 2:
+        # tensorwise scaling
+        return float8_desugar_op(aten_op, args, kwargs)
+
+    t, new_shape = args[0], args[1]
+    # for now, only support reshaping to [-1, dim] or [dim, -1]
+    if len(new_shape) == 2:
+        if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [1, -1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+        elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [-1, 1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+    raise AssertionError(
+        f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet."
+    )
+
+
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)

     def make_float8(data):
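Note on float8_view above: with axiswise scaling, the data cannot be reshaped independently of its scale, so only two reshape patterns are supported: keeping the first dim when the scale has a leading 1, and collapsing leading dims when the scale has a trailing 1; in both cases the scale is viewed to [1, -1] or [-1, 1] alongside the data. A rough sketch of the [-1, last_dim] case with plain tensors (illustrative only, not the Float8Tensor API):

import torch

data = torch.randn(2, 3, 8)
scale = data.abs().amax(dim=-1, keepdim=True)  # row-wise scale, shape [2, 3, 1]

# collapsing leading dims preserves rows, so the scale can be collapsed the same way
new_data = data.view(-1, 8)    # shape [6, 8]
new_scale = scale.view(-1, 1)  # shape [6, 1]
assert torch.equal((data * scale).view(-1, 8), new_data * new_scale)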
@@ -101,6 +161,7 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._gemm_input_role is gemm_input_role
         ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+        _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))

     new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -117,6 +178,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
     "addmm" -> out
     "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)

     def unwrap(x):
         if isinstance(x, Float8Tensor):
@@ -229,6 +291,7 @@ def float8_addmm(aten_op, args, kwargs=None):

 @implements([aten.is_same_size.default])
 def float8_is_same_size(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     return args[0].shape == args[1].shape
@@ -238,6 +301,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8Tensor, presenting as a fp32
     tensor.
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +329,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
     override funcol with FP8 handling
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(
         fp8_input, Float8Tensor
@@ -284,6 +349,7 @@ def allgather_fp8(aten_op, args, kwargs=None):

 @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)
@@ -304,6 +370,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values = args[2]
     assert isinstance(fp8_self, Float8Tensor)
     assert isinstance(fp8_values, Float8Tensor)
+    _assert_tensorwise_scale(aten_op, fp8_self._scale)
     assert fp8_self._scale == fp8_values._scale
     assert fp8_self.dtype == fp8_values.dtype
     assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +401,10 @@ def copy_fp8(aten_op, args, kwargs=None):

     if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         src_hp = src.to_original_precision()
+        _assert_tensorwise_scale(aten_op, src._scale)
         return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        _assert_tensorwise_scale(aten_op, src._scale)
         assert (
             self._orig_dtype == src._orig_dtype
         ), "Expecting both Float8Tensors to be of the same dtype"