Support mixed MX element dtype in `mx_mm` function and `MXLinear`. (#1667)

* Support mixed MX element dtypes in the `mx_mm` function. Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights, and gradients.
* Support an (input, weight, gradient) element dtype tuple in the `MXLinear` layer factory method. Passing a tuple of three element dtypes avoids introducing a breaking change in the current interfaces of `MXLinear` and `swap_linear_with_mx_linear`. Additional unit test coverage has been added for `MXLinear`.
* Use the default `elem_dtype` argument, with optional weight/gradient overrides.
Changed file: `torchao/prototype/mx_formats/README.md` (+4 −5)
```diff
@@ -2,8 +2,8 @@
 This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.

-Note that the current version of the code is written for readability and
-numerical correctness and not yet for optimal performance. We welcome
+Note that the current version of the code is written for readability and
+numerical correctness and not yet for optimal performance. We welcome
 contributions on performance improvements.

 Note that there are no BC guarantees at the moment and we plan to evolve
@@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
```