Summary:
Refactors the MX gemm emulation code to properly emulate the memory layout
constraints we expect from future MX-enabled hardware, where we
expect:
* the first argument to the MX gemm to be required to be in row-major
memory format
* the second argument to the MX gemm to be required to be in col-major
memory format
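
To make the two constraints above concrete, here is a minimal sketch in
plain Python. It is not the code from this PR: the helper names
(`is_row_major`, `is_col_major`, `check_mx_gemm_layout`) are hypothetical,
and strides are assumed to be counted in elements for 2-D operands.

```python
# Hypothetical helpers for illustration only; not part of the PR.
# Shapes and strides are 2-tuples, with strides counted in elements.

def is_row_major(shape, strides):
    """True if a (rows, cols) tensor is laid out row by row."""
    rows, cols = shape
    return strides == (cols, 1)


def is_col_major(shape, strides):
    """True if a (rows, cols) tensor is laid out column by column."""
    rows, cols = shape
    return strides == (1, rows)


def check_mx_gemm_layout(a_shape, a_strides, b_shape, b_strides):
    """Raise ValueError unless A is row-major and B is col-major."""
    if a_shape[1] != b_shape[0]:
        raise ValueError("inner dimensions of A and B do not match")
    if not is_row_major(a_shape, a_strides):
        raise ValueError("first gemm argument must be row-major")
    if not is_col_major(b_shape, b_strides):
        raise ValueError("second gemm argument must be col-major")


# A is (M=4, K=8) row-major, B is (K=8, N=16) col-major: accepted.
check_mx_gemm_layout((4, 8), (8, 1), (8, 16), (1, 8))

# A row-major B with strides (16, 1) would be rejected with ValueError.
```
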
Note that two morally unrelated issues were uncovered with this
refactor:
1. when autocast is on, compile is no longer matching eager numerics.
Since the "before this PR" state isn't really representative of the
world, I'm treating this as a newly uncovered issue, and we can fix
it in a future PR.
2. our transpose logic for fp4 packed into two elements per byte doesn't
work for tensors of shape (M, 1), because we currently rely on the
`is_contiguous()` function to see if our tensor was transposed. We
could work around this, but I'm punting on it until it becomes
important. I expect most tensors in real world usage with MX to not
hit this case.
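
A sketch of why the (M, 1) case in item 2 is ambiguous, written in plain
Python rather than with torch so it runs anywhere. The `is_contiguous`
helper below is an assumption: it mimics the usual row-major contiguity
rule in which strides of size-1 dimensions are ignored. It is not the code
from this PR.

```python
# Hypothetical model of a contiguity check; strides are in elements.

def is_contiguous(shape, strides):
    """Row-major contiguity check that skips size-1 dimensions."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue  # the stride of a size-1 dim never matters
        if stride != expected:
            return False
        expected *= size
    return True


def transpose(shape, strides):
    """Transpose a 2-D view by swapping its sizes and strides."""
    return shape[::-1], strides[::-1]


# For a general (M, K) tensor, transposing flips the contiguity flag,
# so the flag can be used to detect a transpose.
shape, strides = (4, 8), (8, 1)
print(is_contiguous(shape, strides))              # True
print(is_contiguous(*transpose(shape, strides)))  # False

# For an (M, 1) tensor, both the original and the transposed view
# report contiguous, so the flag carries no transpose information.
shape, strides = (4, 1), (1, 1)
print(is_contiguous(shape, strides))              # True
print(is_contiguous(*transpose(shape, strides)))  # True
```
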
Test Plan:
```
pytest test/prototype/mx_formats/ -s -x
```
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: af87d8b65132372f4915312ea71482f6862c4df2
ghstack-comment-id: 2605962974
Pull Request resolved: #1593