
Commit 16d6f39

[webgpu] Fix the continuation issue (#23999)
### Description
In the WebGPU flash attention shader, include the past sequence length when computing the causal mask length for the GQA path, and pass the new `past_sequence_length` uniform from `ApplyFlashAttention`.

### Motivation and Context
With the previous formula (`q_idx_global + 1`), a continuation pass that starts from an existing KV cache masked out valid past keys; adding `uniforms.past_sequence_length` keeps them visible.
1 parent fe43537 commit 16d6f39

File tree

2 files changed, 3 insertions(+), 1 deletion(-)
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 2 additions & 1 deletion

@@ -300,7 +300,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
       qk_4 = qk_4 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+12, head_idx);
     }
-    let seq_causal_length = select(uniforms.total_sequence_length, q_idx_global + 1, uniforms.is_gqa > 0);
+    let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0);
     // Neuter qk values where K is out of bounds.
     qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length);
     qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length);

@@ -451,6 +451,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(parameters.past_present_share_buffer_ ? parameters.past_sequence_length_ : parameters.total_sequence_length_)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
                             {static_cast<uint32_t>(parameters.is_gqa_ ? 1 : 0)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {alpha}});
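
The shader change above is the core of the fix: WGSL's select(f, t, cond) returns t when cond is true, so on the GQA path the causal window now ends at past_sequence_length + q_idx_global + 1 instead of q_idx_global + 1. As a rough host-side sketch of that masking rule (not part of the commit; SeqCausalLength and MaskOutOfBounds are made-up helper names):

// Mirrors the WGSL logic: a GQA query at global index q_idx_global may attend
// to all past tokens plus the new tokens up to and including itself.
#include <cstdint>
#include <limits>
#include <vector>

uint32_t SeqCausalLength(uint32_t total_sequence_length,
                         uint32_t past_sequence_length,
                         uint32_t q_idx_global,
                         bool is_gqa) {
  return is_gqa ? past_sequence_length + q_idx_global + 1
                : total_sequence_length;
}

// Keys at or beyond seq_causal_length get the minimum score, so softmax
// assigns them (effectively) zero probability, matching the qk_* select() lines.
void MaskOutOfBounds(std::vector<float>& qk, uint32_t k_start,
                     uint32_t seq_causal_length) {
  const float min_value = std::numeric_limits<float>::lowest();
  for (uint32_t i = 0; i < qk.size(); ++i) {
    if (k_start + i >= seq_causal_length) {
      qk[i] = min_value;
    }
  }
}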

onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
+                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"is_gqa", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32});
