Bug: MXLinear backward pass implementation #1501

@balancap

Description

In the current implementation of the MXLinear layer:

```python
def forward(self, x):
    x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size)
    w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size)
    y = F.linear(x_mx, w_mx, self.bias)
    y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size)
    return y
```

there is only a single MX quantization step applied to the output gradient (in NoopFwToMXBw).

However, following the MX microscaling paper, there should be 4 quantization steps in the backward pass: two of the output gradient (along 2 different axes), one of the activations, and one of the weights (both along a different axis than in the forward pass).
(Figure: "microscaling-fwd-bwd", the forward/backward quantization scheme from the MX microscaling paper.)
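
To spell out where those 4 backward quantizations come from (this is just the standard linear-layer calculus, nothing MX-specific): for a 2D activation $X \in \mathbb{R}^{N \times d_{\mathrm{in}}}$ and weight $W \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}$, the matmuls and their reduction axes are

$$
\begin{aligned}
Y &= X W^{\top} && \text{(reduce over } d_{\mathrm{in}}\text{): quantize } X, W \text{ along } d_{\mathrm{in}} \\
\frac{\partial L}{\partial X} &= \frac{\partial L}{\partial Y}\, W && \text{(reduce over } d_{\mathrm{out}}\text{): quantize } \frac{\partial L}{\partial Y} \text{ and } W \text{ along } d_{\mathrm{out}} \\
\frac{\partial L}{\partial W} &= \left(\frac{\partial L}{\partial Y}\right)^{\top} X && \text{(reduce over } N\text{): quantize } \frac{\partial L}{\partial Y} \text{ and } X \text{ along } N
\end{aligned}
$$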

Why it matters: even though not officially confirmed by hardware vendors, it is clear that MX matmuls can only be fully optimized if the quantization axis corresponds to the reduction axis for both operands. Hence, running the MX backward pass on next-gen hardware will require the 4 quantization steps presented above. Changing the quantization axis changes the quantization error, meaning that the current implementation potentially does not give a full picture of what MX training will look like on real hardware.

Potential fix: I believe we need a full implementation of the forward + backward pass of the blockwise_quantize_linear function, manually handling the backward-pass quantization steps (see the sketch below).
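
A minimal sketch of what that could look like, assuming a custom torch.autograd.Function, that MXTensor.to_mx quantizes along the last dimension, and that F.linear on MXTensor operands dispatches to an MX matmul as in the forward above. The class name BlockwiseQuantizeLinear and the import path are assumptions rather than existing torchao APIs, and only the 2D-input case is handled:

```python
import torch
import torch.nn.functional as F

# Import path for the prototype MX dtypes; may differ between torchao versions.
from torchao.prototype.mx_formats.mx_tensor import MXTensor


class BlockwiseQuantizeLinear(torch.autograd.Function):
    """Linear layer with explicit MX quantization in both forward and backward."""

    @staticmethod
    def forward(ctx, x, weight, bias, elem_dtype, block_size):
        # Keep the high-precision tensors: the backward pass re-quantizes them
        # along different axes than the forward pass.
        ctx.save_for_backward(x, weight)
        ctx.elem_dtype = elem_dtype
        ctx.block_size = block_size
        ctx.has_bias = bias is not None
        # Forward matmul reduces over d_in: quantize x and weight along d_in.
        x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
        w_mx = MXTensor.to_mx(weight, elem_dtype, block_size)
        return F.linear(x_mx, w_mx, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        elem_dtype, block_size = ctx.elem_dtype, ctx.block_size

        # (1)+(2) dL/dX = dL/dY @ W reduces over d_out: quantize grad_out along
        # d_out and the weight along d_out (not d_in as in the forward pass).
        go_mx = MXTensor.to_mx(grad_out, elem_dtype, block_size)
        w_t_mx = MXTensor.to_mx(weight.t().contiguous(), elem_dtype, block_size)
        grad_x = F.linear(go_mx, w_t_mx)  # == grad_out @ weight

        # (3)+(4) dL/dW = dL/dY^T @ X reduces over the batch axis N: a second
        # quantization of grad_out (along N) and one of x (along N).
        go_t_mx = MXTensor.to_mx(grad_out.t().contiguous(), elem_dtype, block_size)
        x_t_mx = MXTensor.to_mx(x.t().contiguous(), elem_dtype, block_size)
        grad_w = F.linear(go_t_mx, x_t_mx)  # == grad_out.t() @ x

        grad_bias = grad_out.sum(dim=0) if ctx.has_bias else None
        return grad_x, grad_w, grad_bias, None, None
```

MXLinear.forward would then reduce to BlockwiseQuantizeLinear.apply(x, self.weight, self.bias, self.elem_dtype, self.block_size), replacing the NoopFwToMXBw trick.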
