Commit 0b0e401

kunal-vaishnavi authored and aciddelgado committed
Add logit softcapping to GQA (#876)
### Description

This PR adds the `softcap` attribute to the `GroupQueryAttention` op.

### Motivation and Context

This PR helps resolve the `NaN` output issue with Gemma-2 raised in [this issue](#692).
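For context, attention logit softcapping bounds the raw Q x K' scores with a tanh before the softmax, which is what keeps Gemma-2's logits from overflowing to `NaN` in reduced precision. The snippet below is a minimal NumPy illustration of the capping formula used by Gemma-2 (`softcap * tanh(scores / softcap)`), not the GroupQueryAttention kernel itself:

```python
import numpy as np

def softcap_scores(scores: np.ndarray, softcap: float) -> np.ndarray:
    """Illustrative only: squash raw attention scores into (-softcap, softcap)."""
    if softcap == 0.0:
        # Matches the builder's default: 0.0 means softcapping is disabled.
        return scores
    return softcap * np.tanh(scores / softcap)

# Scores that would overflow a float16 softmax stay bounded after capping.
scores = np.array([1e4, -1e4, 3.0], dtype=np.float32)
print(softcap_scores(scores, softcap=50.0))  # -> approximately [50., -50., 3.0]
```

With `softcap=0.0` the scores pass through unchanged, which is why the builder can safely default to `0.0` for models that do not use softcapping.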
1 parent e7cc669 commit 0b0e401

File tree: 1 file changed (+7 -2 lines)

Diff for: src/python/py/models/builder.py (+7 -2)
```diff
@@ -147,6 +147,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }
 
         # LayerNorm-specific variables
+        epsilon = config.rms_norm_eps if hasattr(config, "rms_norm_eps") else 1e-06
         self.layernorm_attrs = {
             "simple": True,                  # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
             "first_layernorm": True,         # 1st LayerNorm = LayerNorm, then SkipLayerNorm for all subsequent LayerNorms
@@ -156,6 +157,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "output_0": "",                  # Output 0 for LayerNorm and SkipLayerNorm
             "output_3": "",                  # Output 3 for SkipLayerNorm
             "add_offset": 0,                 # Offset value for LayerNorm weight
+            "epsilon": epsilon,              # Epsilon value to avoid `sqrt(0)` in LayerNorm
         }
 
         # MatMul-specific variables
```
```diff
@@ -212,6 +214,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         }
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
+        softcap = config.attn_logit_softcapping if hasattr(config, "attn_logit_softcapping") else 0.0  # default is 0.0 in GroupQueryAttention kernel
+
         # Block-sparse attention-specific variables
         sparse_block_size = config.blocksparse_block_size if hasattr(config, "blocksparse_block_size") else 0
         kernel_block_size = config.blocksparse_triton_kernel_block_size if hasattr(config, "blocksparse_triton_kernel_block_size") else 0
@@ -224,6 +228,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "v_path": "",                    # V path to attention
             "op_type": "MultiHeadAttention", # Attention op to use
             "scale": 1 / np.sqrt(self.head_size), # Scale value after calculating Q x K' in attention
+            "softcap": softcap,              # Softcap value to prevent values from exploding in attention
             "use_rotemb_in_attn": False,     # Use rotary embeddings within attention (instead of a separate RotaryEmbedding op)
             "use_packed_matmul": False,      # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
             "block_sparse": {                # Block-sparse attention-specific variables
```
```diff
@@ -969,7 +974,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
 
         name = f"/model/layers.{layer_id}/{location}_layernorm/{'Skip' if skip else ''}LayerNorm"
         op_type = f"{'Skip' if skip else ''}{'Simplified' if simple else ''}LayerNormalization"
-        kwargs = {"epsilon": 9.999999747378752e-06}
+        kwargs = {"epsilon": self.layernorm_attrs["epsilon"]}
         if not skip:
             kwargs.update({"axis": -1, "stash_type": 1})
 
```
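The value being replaced here, `9.999999747378752e-06`, is just `1e-05` rounded to float32; the builder now forwards the model's own `rms_norm_eps` instead. The epsilon's job is easiest to see in a plain RMSNorm, which is what `SimplifiedLayerNormalization` computes; the sketch below is a NumPy illustration rather than the ONNX Runtime kernel:

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, epsilon: float = 1e-06) -> np.ndarray:
    """Illustrative RMSNorm: epsilon keeps the denominator away from sqrt(0)."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + epsilon) * weight

x = np.zeros(8, dtype=np.float32)   # an all-zero row would otherwise divide by zero
w = np.ones(8, dtype=np.float32)
print(rms_norm(x, w))               # finite output thanks to epsilon
```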

```diff
@@ -1381,7 +1386,7 @@ def make_group_query_attention(self, name, **kwargs):
         self.make_node(
             "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
             num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, scale=self.attention_attrs["scale"], # local_window_size=self.window_size, # Disable sliding window attribute temporarily
-            do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
+            softcap=self.attention_attrs["softcap"], do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
         )
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])
 
```
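Once a model is exported with the updated builder, the new attribute can be checked directly on the graph. A minimal sketch using the `onnx` Python package (the path is a placeholder for your exported model):

```python
import onnx

# Placeholder path to the model produced by builder.py.
model = onnx.load("model.onnx")

for node in model.graph.node:
    if node.op_type == "GroupQueryAttention":
        attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
        print(node.name, "softcap =", attrs.get("softcap"))
```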