feat(metal): SDPA with Relative Positional Embeddings & LayerNorm Stability #3280
Description
This PR introduces a new fused Scaled Dot Product Attention (SDPA) kernel that supports relative positional encodings, primarily to enable the SAM3 backbone. It also extends SDPA support to smaller head dimensions (`head_dim=16`) and refactors the Metal LayerNorm kernel to use a two-pass algorithm for better numerical stability.

Key Changes
SDPA with Relative Positional Encodings

- Added `sdpa_with_rel_pos` to `candle-nn` ops.
- New `attention_rel_pos` kernel in Metal that injects `term_h` and `term_w` biases directly into the attention scores (a sketch of the bias math follows this list).
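For orientation, here is a minimal CPU sketch of how decomposed relative-position biases combine with the raw scores. The `[q_h, q_w, k_h]` / `[q_h, q_w, k_w]` layouts for `term_h` / `term_w` and the flat row-major indexing are assumptions for illustration; the fused kernel's actual memory layout may differ.

```rust
/// CPU model of the bias injection. `scores` is [q_h*q_w, k_h*k_w] flattened,
/// `term_h` is [q_h, q_w, k_h] and `term_w` is [q_h, q_w, k_w] (shapes assumed
/// from the SAM-style decomposed formulation, not taken from this kernel).
fn add_rel_pos_bias(
    scores: &mut [f32],
    term_h: &[f32],
    term_w: &[f32],
    (qh, qw): (usize, usize),
    (kh, kw): (usize, usize),
) {
    for q_h in 0..qh {
        for q_w in 0..qw {
            let q = q_h * qw + q_w;
            for k_h in 0..kh {
                for k_w in 0..kw {
                    let k = k_h * kw + k_w;
                    // The bias decomposes into a height term and a width term.
                    scores[q * kh * kw + k] +=
                        term_h[q * kh + k_h] + term_w[q * kw + k_w];
                }
            }
        }
    }
}
```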
Small Head Dimension Support

- Added support for `head_dim=16` across the SDPA kernels.
- New specialized kernel (`sdpa_vector_p16`) that processes 2 heads per SIMD group for these smaller dimensions (see the packing sketch below).
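To make the "2 heads per SIMD group" idea concrete, here is one plausible lane-to-head mapping, written as plain Rust rather than Metal. The 32-lane SIMD width and this particular packing are assumptions for illustration, not the actual `sdpa_vector_p16` layout.

```rust
/// One possible packing of two head_dim=16 heads into a 32-lane SIMD group:
/// lanes 0..16 cover the even head, lanes 16..32 cover the odd head.
/// Illustrative only; the kernel's real layout may differ.
const SIMD_WIDTH: usize = 32;
const HEAD_DIM: usize = 16;

fn lane_assignment(lane: usize, simd_group: usize) -> (usize, usize) {
    let head = simd_group * 2 + lane / HEAD_DIM; // which of the two packed heads
    let dim = lane % HEAD_DIM;                   // which channel within that head
    (head, dim)
}

fn main() {
    for lane in 0..SIMD_WIDTH {
        let (head, dim) = lane_assignment(lane, 0);
        println!("lane {lane:2} -> head {head}, dim {dim}");
    }
}
```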
LayerNorm Stability

- Rewrote the `layernorm` kernel in `reduce.metal` to use a two-pass mean/variance computation (illustrated below).
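The stability argument in a nutshell: centering before squaring avoids the cancellation that the single-pass E[x²] − E[x]² form suffers when the mean is large relative to the spread. A minimal CPU sketch (not the Metal kernel itself):

```rust
/// Two-pass LayerNorm: compute the mean first, then the variance of the
/// centered values. This sidesteps the catastrophic cancellation that
/// E[x^2] - E[x]^2 hits when E[x^2] and mean^2 are large and nearly equal.
fn layernorm_two_pass(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    // Pass 1: mean.
    let mean = x.iter().sum::<f32>() / n;
    // Pass 2: variance of centered values.
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter().map(|v| (v - mean) * inv_std).collect()
}

fn main() {
    // A large offset is exactly the regime where the single-pass
    // formulation loses precision in f32.
    let x: Vec<f32> = (0..8).map(|i| 1000.0 + i as f32 * 0.01).collect();
    println!("{:?}", layernorm_two_pass(&x, 1e-5));
}
```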
Fixes & Tests

- Added a `conv_transpose2d_bf16` test to verify the fix for the BF16 ConvTranspose2d bug (where it was incorrectly dispatching to 1D kernels); a test sketch follows this list.
- Added tests for `fused_ops` in `candle-metal-kernels`.
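For context, a rough sketch of what such a regression test can look like: run the BF16 path on Metal and compare against a CPU f32 reference. The shapes, the `conv_transpose2d(&w, 0, 0, 1, 1)` argument values, and the tolerance are illustrative assumptions; the test actually added in this PR may be structured differently.

```rust
use candle_core::{DType, Device, Result, Tensor};

// Sketch of a BF16 ConvTranspose2d check on Metal vs. a CPU f32 reference.
// Shapes and tolerance are illustrative, not the PR's actual test values.
fn conv_transpose2d_bf16_check() -> Result<()> {
    let cpu = Device::Cpu;
    let metal = Device::new_metal(0)?;

    // NCHW input and (in_channels, out_channels, kh, kw) kernel.
    let x = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &cpu)?;
    let w = Tensor::randn(0f32, 1f32, (4, 2, 3, 3), &cpu)?;

    // f32 reference on CPU.
    let reference = x.conv_transpose2d(&w, 0, 0, 1, 1)?;

    // BF16 run on Metal, cast back to f32 for comparison.
    let x_m = x.to_device(&metal)?.to_dtype(DType::BF16)?;
    let w_m = w.to_device(&metal)?.to_dtype(DType::BF16)?;
    let got = x_m
        .conv_transpose2d(&w_m, 0, 0, 1, 1)?
        .to_dtype(DType::F32)?
        .to_device(&cpu)?;

    let a = reference.flatten_all()?.to_vec1::<f32>()?;
    let b = got.flatten_all()?.to_vec1::<f32>()?;
    for (r, g) in a.iter().zip(b.iter()) {
        // BF16 carries roughly 3 decimal digits, so keep the tolerance loose.
        assert!((r - g).abs() < 5e-2, "mismatch: {r} vs {g}");
    }
    Ok(())
}
```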
Motivation

The primary driver was porting SAM3, which requires specific attention variants not previously supported. While debugging inference mismatches, we also identified that the previous single-pass LayerNorm was accumulating too much error in deep networks, necessitating the rewrite.
Testing
- Added the `conv_transpose2d_bf16` test case.