Description
In the current implementation of the MXLinear layer:

    def forward(self, x):
        x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
        w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
        y = F.linear(x_mx, w_mx, self.bias)
        y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
        return y

there is only a single MX quantization step for the output gradient (in NoopFwToMXBw).
However, according to the MX (microscaling) paper, the backward pass should involve four quantizations: two of the output gradient (along two different axes), one of the activation, and one of the weights, each distinct from the quantization used in the forward pass.
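
To make the count concrete, here is a sketch of the three matmuls of a linear layer, assuming a 2-D activation $x$ of shape $(M, K)$, a weight $W$ of shape $(N, K)$ as in `F.linear`, a loss $L$, and no bias. The symbols $M$, $N$, $K$ are my own notation, not taken from the paper.

```math
y_{mn} = \sum_{k} x_{mk} W_{nk}, \qquad
\frac{\partial L}{\partial x_{mk}} = \sum_{n} \frac{\partial L}{\partial y_{mn}} W_{nk}, \qquad
\frac{\partial L}{\partial W_{nk}} = \sum_{m} \frac{\partial L}{\partial y_{mn}} x_{mk}
```

The forward matmul reduces over $k$, the input-gradient matmul over $n$, and the weight-gradient matmul over $m$. If quantization blocks have to lie along the reduction axis, the backward pass needs the output gradient quantized along $n$ and again along $m$, the weights along $n$, and the activation along $m$. None of these matches the forward quantization along $k$.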

Why does it matter: even though this is not officially confirmed by hardware vendors, it seems clear that MX matmuls can only be fully optimized if the quantization axis corresponds to the reduction axis for both operands. Hence, running the MX backward pass on next-generation hardware will require the four quantization steps described above. Changing the quantization axis results in a different quantization error, so the current implementation potentially does not give a full picture of what MX training will look like on real hardware.
Potential fix: I believe we need a full implementation of the forward and backward passes of a blockwise_quantize_linear function, handling the backward-pass quantization steps manually.
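
A minimal sketch of what such a function could look like, to illustrate the idea rather than propose the final code. Everything here is an assumption: `fake_mx_quantize` and `MXLinearFn` are hypothetical names, the quantizer only simulates MX (a shared power-of-two scale per block, elements rounded to a signed integer grid, then dequantized) and does not use `MXTensor`, and the input is restricted to 2-D with every dimension divisible by the block size.

```python
import torch
import torch.nn.functional as F


def fake_mx_quantize(t, axis, block_size=32, elem_bits=8):
    """Hypothetical stand-in for MX quantization (quantize, then dequantize).

    Blocks of `block_size` values along `axis` share one power-of-two scale;
    elements are rounded to a signed integer grid with `elem_bits` bits.
    """
    moved = t.movedim(axis, -1)
    shape = moved.shape
    assert shape[-1] % block_size == 0, "axis length must be a multiple of block_size"
    blocks = moved.reshape(*shape[:-1], shape[-1] // block_size, block_size)
    qmax = 2 ** (elem_bits - 1) - 1
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-30)
    # Smallest power-of-two scale that keeps the block maximum within the grid.
    scale = torch.exp2(torch.ceil(torch.log2(amax / qmax)))
    q = torch.clamp(torch.round(blocks / scale), -qmax, qmax)
    return (q * scale).reshape(shape).movedim(-1, axis)


class MXLinearFn(torch.autograd.Function):
    """Linear layer sketch where each matmul operand is quantized along its reduction axis."""

    @staticmethod
    def forward(ctx, x, weight, bias, block_size):
        # Save the unquantized tensors so the backward pass can requantize them
        # along different axes.
        ctx.save_for_backward(x, weight)
        ctx.block_size = block_size
        ctx.has_bias = bias is not None
        # y = x @ W^T reduces over in_features, the last axis of both operands.
        x_q = fake_mx_quantize(x, -1, block_size)
        w_q = fake_mx_quantize(weight, -1, block_size)
        return F.linear(x_q, w_q, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight = ctx.saved_tensors
        bs = ctx.block_size
        # grad_x = grad_y @ W reduces over out_features.
        gy_n = fake_mx_quantize(grad_y, -1, bs)  # quantization 1: grad_y along out_features
        w_n = fake_mx_quantize(weight, 0, bs)    # quantization 2: weight along out_features
        grad_x = gy_n @ w_n
        # grad_W = grad_y^T @ x reduces over the batch axis.
        gy_m = fake_mx_quantize(grad_y, 0, bs)   # quantization 3: grad_y along batch
        x_m = fake_mx_quantize(x, 0, bs)         # quantization 4: activation along batch
        grad_w = gy_m.t() @ x_m
        grad_b = grad_y.sum(dim=0) if ctx.has_bias else None
        # One gradient per forward argument; block_size gets None.
        return grad_x, grad_w, grad_b, None


# Usage: batch 64, in_features 32, out_features 96, all multiples of the block size.
x = torch.randn(64, 32, requires_grad=True)
w = torch.randn(96, 32, requires_grad=True)
b = torch.zeros(96, requires_grad=True)
MXLinearFn.apply(x, w, b, 32).sum().backward()
```

Saving the unquantized `x` and `weight` for the backward pass is a design choice of this sketch: it keeps the four backward quantizations independent of the forward ones, at the cost of holding high-precision tensors in memory. Whether real hardware or a production implementation would requantize from high precision or from the forward MX tensors is left open here.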