Commit a97259e
mx formats: create MXLinearConfig
Summary:

Creating a config to make it easier to enable more configuration options. This is important as we enable more functionality from pytorch/pytorch#146414 in torchao.

Note that future PRs will flesh out the config, to keep PR size small.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d6b63187f87dee1467ba2a6377863023df7d9f6f
ghstack-comment-id: 2648821215
Pull Request resolved: #1688
1 parent e7914e9 commit a97259e
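
A minimal usage sketch of what this change enables (the toy model and the elem_dtype value below are illustrative assumptions, not part of this commit; the swap API is taken from the mx_linear.py diff below):

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Hypothetical toy model; any module containing nn.Linear works.
m = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False))

# New in this PR: kernel-selection options live on a config object instead of a
# module-level global in torchao.prototype.mx_formats.config.
config = MXLinearConfig(use_fp4_custom_triton_dequant_kernel=False)

# elem_dtype is still passed separately; a TODO in the diff moves it (and
# block_size) into the config in a follow-up PR.
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, config=config, block_size=32)
```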

File tree

5 files changed (+88, -21 lines)


test/prototype/mx_formats/test_mx_tensor.py

+13-7
@@ -7,7 +7,6 @@
 import pytest
 import torch
 
-from torchao.prototype.mx_formats import config
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
@@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype):
     else:
         raise AssertionError("unsupported")
     block_size = 2
+    use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
+        scale_e8m0_bits,
+        data_bits,
+        elem_dtype,
+        block_size,
+        torch.float,
+        use_fp4_custom_triton_dequant_kernel,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
     assert torch.all(torch.isnan(tensor_hp[0:1]))
@@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton):
     M, K = 128, 256
     block_size = 32
     tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-    tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    config.use_fp4_custom_triton_dequant_kernel = fp4_triton
+    tensor_mx = MXTensor.to_mx(
+        tensor_hp,
+        elem_dtype,
+        block_size,
+        use_fp4_custom_triton_dequant_kernel=fp4_triton,
+    )
     tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
-    config.use_fp4_custom_triton_dequant_kernel = False
 
     tensor_mx_t = tensor_mx.t()
-    config.use_fp4_custom_triton_dequant_kernel = fp4_triton
     tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
-    config.use_fp4_custom_triton_dequant_kernel = False
 
     assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
torchao/prototype/mx_formats/config.py

+13-2

@@ -1,2 +1,13 @@
-# If True, uses a custom triton kernel for fp4 dequantize
-use_fp4_custom_triton_dequant_kernel = False
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass
+class MXLinearConfig:
+    # If True, uses a custom triton kernel for fp4 dequantize
+    use_fp4_custom_triton_dequant_kernel: bool = False
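
The config is a plain dataclass, so callers construct it directly. A minimal sketch (the field name comes from the diff above; everything else is illustrative):

```python
from torchao.prototype.mx_formats.config import MXLinearConfig

# Default: do not use the custom triton fp4 dequant kernel.
config_default = MXLinearConfig()

# Opt in to the custom kernel (only relevant when the element dtype is fp4).
config_fp4_triton = MXLinearConfig(use_fp4_custom_triton_dequant_kernel=True)
```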

torchao/prototype/mx_formats/mx_linear.py

+19-3
@@ -8,11 +8,12 @@
 Defines the prototype UX for converting a model to use mx weights
 """
 
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
 
+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 
 
@@ -110,13 +111,19 @@ def from_float(
         elem_dtype_weight_override=None,
         elem_dtype_grad_output_override=None,
         *,
+        # TODO(next PR): move elem_dtype* and block size into config
+        config: MXLinearConfig = None,
         block_size=32,
     ):
         mod.__class__ = MXLinear
         mod.in_elem_dtype = elem_dtype
         mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
         mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
         mod.block_size = block_size
+        # TODO(next PR): fix this
+        if config is None:
+            config = MXLinearConfig()
+        mod.config = config
         return mod
 
     def forward(self, x):
@@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear):
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size):
+    def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
+        # TODO(next PR): move elem_dtype and block_size into config
+
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
@@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size):
         )
         new_mod.bias = mod.bias
         new_mod.elem_dtype = elem_dtype
+        new_mod.config = config
         return new_mod
 
     @torch.no_grad()
@@ -207,6 +217,8 @@ def swap_linear_with_mx_linear(
     elem_dtype_weight_override=None,
    elem_dtype_grad_output_override=None,
    *,
+    # TODO(next PR): move elem_dtype* and block_size into config
+    config: Optional[MXLinearConfig] = None,
     block_size=32,
     filter_fn=None,
 ):
@@ -225,6 +237,7 @@ def __fn(mod, fqn):
             elem_dtype,
             elem_dtype_weight_override,
             elem_dtype_grad_output_override,
+            config=config,
             block_size=block_size,
         ),
         combined_filter_fn,
@@ -236,6 +249,7 @@ def swap_linear_with_mx_inference_linear(
     elem_dtype,
     block_size,
     filter_fn=None,
+    config: Optional[MXLinearConfig] = None,
 ):
     if filter_fn is None:
         combined_filter_fn = _is_linear
@@ -247,6 +261,8 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size),
+        lambda mod: MXInferenceLinear.from_float(
+            mod, elem_dtype, block_size, config=config
+        ),
         combined_filter_fn,
     )
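
For inference, the same config object is threaded through swap_linear_with_mx_inference_linear. A sketch under the same assumptions (illustrative toy model and elem_dtype; depending on elem_dtype, actually running the quantized module may require a CUDA device):

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
)

m = torch.nn.Sequential(torch.nn.Linear(32, 32, bias=False))
config = MXLinearConfig(use_fp4_custom_triton_dequant_kernel=False)

# elem_dtype and block_size remain explicit arguments in this PR; per the TODOs
# in the diff, a follow-up moves them into MXLinearConfig.
swap_linear_with_mx_inference_linear(
    m, torch.float8_e4m3fn, block_size=32, config=config
)
```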

torchao/prototype/mx_formats/mx_ops.py

+4-2
@@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
+        old._use_fp4_custom_triton_dequant_kernel,
     )
     return new
 
@@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None):
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
+        old._use_fp4_custom_triton_dequant_kernel,
     )
     return new
 
@@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None):
         args[0]._elem_dtype,
         args[0]._block_size,
         args[0]._orig_dtype,
+        args[0]._use_fp4_custom_triton_dequant_kernel,
     )
 
 
@@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     tensor.
     """
     assert isinstance(args[0], MXTensor)
-    # print('before', args[0], args[0].dtype, args[0]._orig_dtype)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
     ), "Only support dtype kwarg for autocast"
@@ -144,6 +146,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         args[0]._elem_dtype,
         args[0]._block_size,
         kwargs["dtype"],
+        args[0]._use_fp4_custom_triton_dequant_kernel,
     )
-    # print('after', res, res.dtype, res._orig_dtype)
     return res
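
The mx_ops.py changes make the kernel flag travel with the tensor through the op overrides (desugar, transpose, view, autocast) instead of being read from a module-level global at dequant time. A small illustrative sketch (attribute name from the diff above; the elem_dtype value is an assumption):

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(64, 64, dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x,
    torch.float8_e4m3fn,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=False,
)

# The flag is stored on the tensor subclass and propagated by mx_t / mx_view_op.
assert x_mx._use_fp4_custom_triton_dequant_kernel is False
assert x_mx.t()._use_fp4_custom_triton_dequant_kernel is False
```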

torchao/prototype/mx_formats/mx_tensor.py

+39-7
@@ -21,7 +21,6 @@
 
 import torch
 
-import torchao.prototype.mx_formats.config as config
 from torchao.prototype.mx_formats.constants import (
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
     return s_fp
 
 
-def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
+def to_dtype(
+    data_lp,
+    scale_e8m0,
+    elem_dtype,
+    block_size,
+    target_dtype,
+    use_fp4_custom_triton_dequant_kernel,
+):
     orig_shape = data_lp.shape
     is_transposed = not data_lp.is_contiguous()
     # if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
         data_hp = f6_e3m2_unpacked_to_f32(data_lp)
         data_hp = data_hp.to(target_dtype)
     elif elem_dtype == DTYPE_FP4:
-        if config.use_fp4_custom_triton_dequant_kernel:
+        if use_fp4_custom_triton_dequant_kernel:
             data_hp_rescaled = triton_f4_to_scaled_bf16(
                 data_lp,
                 scale_e8m0,
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
+    def forward(
+        ctx,
+        data_hp,
+        elem_dtype,
+        block_size,
+        scaling_mode,
+        use_fp4_custom_triton_dequant_kernel,
+    ):
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode
         )
         return MXTensor(
-            scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
+            scale_e8m0_biased,
+            data_lp,
+            elem_dtype,
+            block_size,
+            data_hp.dtype,
+            use_fp4_custom_triton_dequant_kernel,
         )
 
     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None
+        return g, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
@@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
             tensor_lp._elem_dtype,
             tensor_lp._block_size,
             target_dtype,
+            tensor_lp._use_fp4_custom_triton_dequant_kernel,
         )
 
     @staticmethod
@@ -360,6 +379,7 @@ def __new__(
         elem_dtype,
         block_size,
         orig_dtype,
+        use_fp4_custom_triton_dequant_kernel,
     ):
         new_size = data_bits.size()
         if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
+        self._use_fp4_custom_triton_dequant_kernel = (
+            use_fp4_custom_triton_dequant_kernel
+        )
         return self
 
     def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+        use_fp4_custom_triton_dequant_kernel: bool = False,
     ):
-        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
+        return ToMXConstrFunc.apply(
+            data_hp,
+            elem_dtype,
+            block_size,
+            scaling_mode,
+            use_fp4_custom_triton_dequant_kernel,
+        )
 
     def __tensor_flatten__(self):
         ctx = {
             "_elem_dtype": self._elem_dtype,
             "_block_size": self._block_size,
             "_orig_dtype": self._orig_dtype,
+            "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
         }
         return ["_scale_e8m0", "_data"], ctx
 
@@ -467,6 +498,7 @@ def __tensor_unflatten__(
             metadata["_elem_dtype"],
             metadata["_block_size"],
             metadata["_orig_dtype"],
+            metadata["_use_fp4_custom_triton_dequant_kernel"],
         )
 
     # Do not force the MXTensor type on the returned tensor
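
At the tensor level, the flag is now an explicit argument to MXTensor.to_mx and to to_dtype, and it is preserved by __tensor_flatten__ / __tensor_unflatten__. A small roundtrip sketch (shapes and the elem_dtype value are illustrative assumptions):

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(128, 256, dtype=torch.bfloat16)

# Quantize to an MX format; the kernel choice is now passed per tensor.
x_mx = MXTensor.to_mx(
    x,
    torch.float8_e4m3fn,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=False,
)

# Dequantize back to the original dtype; the stored flag only affects the fp4 path.
x_dq = x_mx.to_dtype(torch.bfloat16)
```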
