[Bugfix] Fix precision loss in LoRA-wrapped RowParallelLinear by fusing bias into GEMM #28972
Issue:
The LoRA-wrapped RowParallelLinear added the bias as a separate bfloat16 operation instead of fusing it into the GEMM kernel as the unwrapped layer does. This caused precision loss: the fused kernel can accumulate in higher precision (FP32) and add the bias before converting to bfloat16, while the separate addition introduces an extra rounding step. The discrepancy appeared even with zero LoRA weights when comparing LoRA-wrapped results against merged-weight results.
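To make the rounding mechanism concrete, here is a standalone PyTorch snippet (illustrative only, not vLLM code). It models the fused path as "accumulate matmul and bias in FP32, cast once" and the separate path as "cast the matmul output to bfloat16, then add a bfloat16 bias", and reports how often the two disagree:

```python
import torch

torch.manual_seed(0)
x = torch.randn(16, 256)
w = torch.randn(128, 256)
b = torch.randn(128)

# "Fused" path: matmul and bias add both happen in FP32,
# with a single final cast to bfloat16.
fused = (x @ w.T + b).to(torch.bfloat16)

# "Separate" path: the matmul output is rounded to bfloat16 first,
# then the bfloat16 bias is added, introducing a second rounding step.
separate = (x @ w.T).to(torch.bfloat16) + b.to(torch.bfloat16)

mismatch = (fused != separate).float().mean().item()
print(f"fraction of elements that differ: {mismatch:.3f}")
```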
Fix:
Pass the bias to apply() only on rank 0 (and only when skip_bias_add=False), allowing the quantization method to fuse the bias addition into the GEMM kernel. This matches the unwrapped layer's behavior and eliminates the precision discrepancy.
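A rough sketch of the fixed call pattern. Names such as apply_gemm, tp_rank, skip_bias_add, and lora_a/lora_b are placeholders for illustration, not vLLM's actual classes or attributes:

```python
import torch
import torch.nn.functional as F


def apply_gemm(x, weight, bias):
    # Stand-in for quant_method.apply(): a bias handed in here can be fused
    # into the GEMM instead of being added as a separate elementwise op.
    return F.linear(x, weight, bias)


def lora_row_parallel_forward(x, weight, bias, lora_a, lora_b,
                              tp_rank: int, skip_bias_add: bool):
    # After the fix: the bias is passed into the GEMM on rank 0 only (and only
    # when bias addition is not deferred), mirroring the unwrapped
    # RowParallelLinear, rather than being added separately afterwards.
    fused_bias = bias if (tp_rank == 0 and not skip_bias_add) else None
    out = apply_gemm(x, weight, fused_bias)
    # The LoRA delta is applied on top of the base output as before.
    out = out + (x @ lora_a) @ lora_b
    return out
```

Passing the bias on a single tensor-parallel rank keeps it from being added multiple times before the all-reduce, which is why the unwrapped layer gates it the same way.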