     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    ScalingGranularity,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -36,21 +37,26 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
+        tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
     ):
         ctx.mm_config = mm_config
+        ctx.scaling_granularity = scaling_granularity
         return tensor

     @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
-        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
+            return gradY, None, None
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype, ctx.scaling_granularity)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


 class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +69,19 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3_dynamic(
+            input, self.forward_config, self.scaling_granularity
+        )
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3_dynamic(
+                self.weight, self.forward_config, self.scaling_granularity
+            )
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_dynamic_bw(
+            y, self.backward_config, self.scaling_granularity
+        )
         return y

     @classmethod
@@ -101,9 +113,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             fp8_output=False,
             pad_inner_dim=config.pad_inner_dim,
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_granularity = ScalingGranularity.TensorWise
+
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_granularity
+                )
             )
         else:
             new_mod.weight = mod.weight
@@ -112,18 +129,31 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
+    reduce_amax: bool = False,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    scale = tensor_to_scale(
+        inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+    )
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
+    )


 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_granularity)


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +173,12 @@ def cast_to_float8_e5m2_dynamic_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -157,24 +192,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._scaling_granularity = scaling_granularity

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
+                args[0]._tensor, args[0]._mm_config, args[0]._scaling_granularity
             )
         mm_config: Optional[ScaledMMConfig] = None
+        scaling_granularity: Optional[ScalingGranularity] = None

         def unwrap(t):
             nonlocal mm_config
+            nonlocal scaling_granularity
             if mm_config is None:
                 mm_config = t._mm_config
             else:
                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+
+            if scaling_granularity is None:
+                scaling_granularity = t._scaling_granularity
+            else:
+                # TODO For now we assume that the scaling granularity is same across all tensors
+                assert scaling_granularity == t._scaling_granularity
             return t._tensor

         args, kwargs = pytree.tree_map_only(
@@ -184,23 +233,33 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(
+                x, mm_config, scaling_granularity
+            ),
+            out,
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        return ["_tensor"], {
+            "_mm_config": self._mm_config,
+            "_scaling_granularity": self._scaling_granularity,
+        }

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        mm_config = flatten_spec["_mm_config"]
+        scaling_granularity = flatten_spec["_scaling_granularity"]
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"], mm_config, scaling_granularity
+        )

     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_granularity={self._scaling_granularity})"

     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
+            self._tensor, self._mm_config, self._scaling_granularity, reduce_amax=True
        )
         return (float8_tensor._data,), (float8_tensor._scale,)
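
For reference, a minimal usage sketch of how the new scaling_granularity threads through the dynamic casts (not part of the diff). It assumes the module path float8_experimental.float8_dynamic_linear, that ScalingGranularity is importable from that module (it is imported there per the first hunk), and an environment where the float8 kernels or the emulate path are available.

    # Hypothetical sketch, not part of this diff. Module path and emulate behavior
    # are assumptions; adjust to your checkout.
    import torch
    import torch.nn as nn

    from float8_experimental.float8_dynamic_linear import (
        Float8DynamicLinear,
        ScalingGranularity,
    )

    # emulate=True exercises the emulation path so no FP8-capable GPU is assumed.
    m = Float8DynamicLinear.from_float(nn.Linear(64, 128, bias=False), emulate=True)

    # from_float() currently hardcodes tensor-wise scaling (see the TODO in the diff).
    assert m.scaling_granularity == ScalingGranularity.TensorWise

    x = torch.randn(16, 64, requires_grad=True)
    y = m(x)            # input/weight cast to e4m3 with the module's granularity
    y.sum().backward()  # gradY cast to e5m2 via NoopFwToFloat8E5M2Bw, same granularity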