
Conversation

@amritsingh183
Contributor

Description

This PR introduces a new fused Scaled Dot Product Attention (SDPA) kernel that supports relative positional encodings, primarily for the SAM3 backbone. It also extends SDPA support to smaller head dimensions (head_dim=16) and refactors the Metal LayerNorm kernel to use a two-pass algorithm for better numerical stability.

Key Changes

SDPA with Relative Positional Encodings

  • Added sdpa_with_rel_pos to candle-nn ops.
  • Implemented a new attention_rel_pos kernel in Metal that injects term_h and term_w biases directly into the attention scores (a reference sketch of the semantics follows this list).
  • This is a critical component for the SAM3 text encoder and vision backbone.
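
For reference, the intended semantics are the SAM-style decomposed bias added to the q·kᵀ scores before the softmax. Below is a minimal unfused sketch in candle; the function name, shapes, and the term_h/term_w layout are illustrative assumptions, not the exact signature added by this PR.

```rust
use candle_core::{Result, Tensor};

// Unfused reference for SDPA with decomposed relative positional bias.
// Assumed (illustrative) shapes: q, k, v are (b, q_len, head_dim);
// term_h is (b, q_len, k_h, 1); term_w is (b, q_len, 1, k_w); k_len = k_h * k_w.
fn sdpa_rel_pos_reference(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    term_h: &Tensor,
    term_w: &Tensor,
    scale: f64,
) -> Result<Tensor> {
    let (b, q_len, _head_dim) = q.dims3()?;
    let k_len = k.dim(1)?;

    // Plain scaled dot-product scores: (b, q_len, k_len).
    let scores = (q.matmul(&k.transpose(1, 2)?)? * scale)?;

    // term_h + term_w broadcasts to (b, q_len, k_h, k_w), which flattens
    // back to (b, q_len, k_len); this is the bias the fused kernel injects.
    let bias = term_h.broadcast_add(term_w)?.reshape((b, q_len, k_len))?;
    let scores = (scores + bias)?;

    let attn = candle_nn::ops::softmax_last_dim(&scores)?;
    attn.matmul(v)
}
```

A fused kernel would be expected to match this reference output, which is also a convenient way to test it against the PyTorch implementation.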

Small Head Dimension Support

  • Added support for head_dim=16 across SDPA kernels.
  • Implemented an optimized packed vector kernel (sdpa_vector_p16) that processes 2 heads per SIMD group for these smaller dimensions (an illustrative mapping is sketched below).
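
The packing idea: with head_dim=16 and 32-lane SIMD groups, two heads fit side by side in one group. A purely illustrative sketch of that lane-to-head mapping (my reading of the idea, not the kernel's actual indexing):

```rust
// Illustrative only: with head_dim = 16 and a 32-lane SIMD group, lanes
// 0..16 serve head 2*g and lanes 16..32 serve head 2*g + 1, so each SIMD
// group covers two heads instead of one.
fn packed_lane_to_head(simd_group: usize, lane: usize, head_dim: usize) -> (usize, usize) {
    let heads_per_group = 32 / head_dim; // 2 when head_dim == 16
    let head = simd_group * heads_per_group + lane / head_dim;
    let dim = lane % head_dim; // which element of that head this lane owns
    (head, dim)
}
```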

LayerNorm Stability

  • Rewrote the layernorm kernel in reduce.metal.
  • Switched from a single-pass algorithm to a two-pass approach (calculate mean -> barrier -> calculate variance). This significantly reduces the floating-point error and catastrophic cancellation observed in some models (the two formulations are sketched below).
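
For intuition, here is the difference between the two formulations in plain scalar Rust (a sketch of the numerics only, not the Metal kernel; the single-pass variant shown is the common E[x²] − E[x]² formulation, which is where the cancellation comes from):

```rust
/// Single-pass variance via E[x^2] - E[x]^2: one loop over the data, but
/// subtracting two large, nearly equal numbers can cancel catastrophically
/// in f32 when the mean is large relative to the spread.
fn mean_var_single_pass(xs: &[f32]) -> (f32, f32) {
    let n = xs.len() as f32;
    let (mut sum, mut sum_sq) = (0.0f32, 0.0f32);
    for &x in xs {
        sum += x;
        sum_sq += x * x;
    }
    let mean = sum / n;
    (mean, sum_sq / n - mean * mean)
}

/// Two-pass variance: compute the mean first, then accumulate (x - mean)^2.
/// The kernel does the same, with a threadgroup barrier between the passes
/// so every thread sees the finished mean before accumulating the variance.
fn mean_var_two_pass(xs: &[f32]) -> (f32, f32) {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let var = xs.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n;
    (mean, var)
}
```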

Fixes & Tests

  • Added a regression test conv_transpose2d_bf16 to verify the fix for the BF16 ConvTranspose2d bug (where it was incorrectly dispatching to 1D kernels); a sketch of what such a test checks follows this list.
  • Exposed fused_ops in candle-metal-kernels.
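
A rough sketch of the check such a regression test performs: run the same transposed convolution in F32 and BF16 and require the results to agree, which a wrong 1D-kernel dispatch would break. The argument order of Tensor::conv_transpose2d, the kernel layout, the shapes, and the tolerance below are assumptions for illustration, not the exact test added in this PR.

```rust
use candle_core::{DType, Device, Result, Tensor};

// Regression-test sketch for the BF16 ConvTranspose2d dispatch bug: the BF16
// output should track the F32 output; dispatching the 1D kernel instead would
// produce garbage and blow past the tolerance. Shapes are arbitrary.
// `dev` would presumably be Device::new_metal(0)? to hit the Metal kernels.
fn conv_transpose2d_bf16_matches_f32(dev: &Device) -> Result<()> {
    let input = Tensor::randn(0f32, 1., (1, 4, 8, 8), dev)?; // (N, C_in, H, W)
    let kernel = Tensor::randn(0f32, 1., (4, 2, 3, 3), dev)?; // (C_in, C_out, kH, kW)

    // F32 reference; assumed args: kernel, padding, output_padding, stride, dilation.
    let ref_f32 = input.conv_transpose2d(&kernel, 0, 0, 1, 1)?;

    // Same op in BF16, compared back in F32.
    let out_bf16 = input
        .to_dtype(DType::BF16)?
        .conv_transpose2d(&kernel.to_dtype(DType::BF16)?, 0, 0, 1, 1)?
        .to_dtype(DType::F32)?;

    let max_diff = (ref_f32 - out_bf16)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    assert!(max_diff < 0.1, "bf16 conv_transpose2d diverged: {max_diff}");
    Ok(())
}
```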

Motivation

The primary driver was porting SAM3, which requires specific attention variants not previously supported. While debugging inference mismatches, we also identified that the previous single-pass LayerNorm was accumulating too much error in deep networks, necessitating the rewrite.

Testing

  • Added conv_transpose2d_bf16 test case.
  • Verified SDPA output against the PyTorch reference for the SAM3 implementation.

@ivarflakstad
Member

Hey!
Hmm. This is like 4 PRs in one hehe. One of them you've actually already opened though (bf16 conv2d).

I'm definitely interested in an updated sdpa kernel, but preferably as a standalone PR.
Is it from mlx? Can't seem to find it in their repo. Is it vibe coded?

I don't think we need to update the layernorm kernels, because I've just opened a PR here which fixes precision and improves performance. The two-pass approach would also fix precision, but degrades performance.

The "fused" ops aren't actually fused with anything (aka not fused), so I'll pass on those as well.

@amritsingh183
Contributor Author

  • Let me create a new PR for SDPA (it was created by an LLM council of 3 Claude Opus 4.5 and 2 Gemini 3 members).
  • Let me look into the fusion and see how to get it right (it's a bummer for me).
  • I looked at [Metal] improve normalization #3283 and it is what I need.

Let me close this PR.

@ivarflakstad
Member

  • Let me create a new PR for SDPA (it was created by an LLM council of 3 Claude Opus 4.5 and 2 Gemini 3 members).

If the council manages to get it right (which to be honest I kind of doubt) then it would probably be a good idea to get a PR opened against mlx, since that is where our sdpa kernel originates from. To ensure correctness you probably want fairly extensive testing.

  • Let me look into the fusion and see how to get it right (it's a bummer for me).

Just FYI, we don't need the ops you attempted to fuse to be fused ops. They're fine as is.

Glad to hear it! 👍

@amritsingh183
Contributor Author

Got it about the fusion part. The council has mostly been a waste of time; I think the current LLMs are not there yet. But it was worth a try, in the hope that it would help the candle community. I need to be more thorough with tests next time, because wasting the maintainers' time on a bad PR is also not acceptable.
