[CPU] Add fp16 support to sparse attention #24015
base: main
Conversation
      q, head_size, k, head_size, output, total_seq_len,
      MLFloat16(alpha).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
} else {
  size_t bytes = head_size * (sequence_length + total_seq_len) * sizeof(float);
Check failure: Code scanning / CodeQL
Multiplication result converted to larger type (High)
Copilot Autofix AI about 11 hours ago
To fix the problem, we need to ensure that the multiplication is performed using a larger integer type, so that it cannot overflow before the result is stored.
- Cast one of the operands to `size_t` before performing the multiplication, so the whole expression is evaluated in the larger type.
- Specifically, cast `head_size` to `size_t` in the multiplication expression on line 236.
- No additional methods, imports, or definitions are needed to implement this change.
@@ -235,3 +235,3 @@
  } else {
-   size_t bytes = head_size * (sequence_length + total_seq_len) * sizeof(float);
+   size_t bytes = static_cast<size_t>(head_size) * (sequence_length + total_seq_len) * sizeof(float);
    auto q_k_fp32 = allocator->Alloc(bytes);
  BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
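A note on why this line is flagged even though `bytes` is declared `size_t`: the product on the right-hand side is evaluated in `int` first (assuming the dimension variables are 32-bit `int`s, as the alert implies), and only the finished, possibly wrapped, result is widened. A minimal sketch with hypothetical sizes:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical dimensions, chosen so the int product exceeds INT_MAX.
  int head_size = 256;
  int sequence_length = 1;
  int total_seq_len = 9'000'000;

  // Flagged pattern: the product is computed in int, so it can wrap
  // (formally, signed overflow is undefined behavior) before the result
  // is ever converted to size_t:
  //   size_t bad = head_size * (sequence_length + total_seq_len) * sizeof(float);

  // Fixed pattern: widening one operand first makes the whole
  // left-to-right chain of multiplications evaluate in size_t.
  size_t good = static_cast<size_t>(head_size) *
                (sequence_length + total_seq_len) * sizeof(float);

  std::cout << good << '\n';  // 9216001024, which does not fit in 32 bits
  return 0;
}
```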
  float* q_fp32 = static_cast<float*>(q_k_fp32);
  MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
Check failure: Code scanning / CodeQL
Multiplication result converted to larger type (High)
Copilot Autofix AI about 11 hours ago
To fix the problem, we need to ensure that the multiplication is performed using `size_t`, which has a larger range than `int`, by casting one or both operands before multiplying. Specifically, we will modify the line where the multiplication occurs to cast `head_size` to `size_t` before multiplying it by `sequence_length`.
@@ -240,3 +240,3 @@
    float* q_fp32 = static_cast<float*>(q_k_fp32);
-   MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+   MlasConvertHalfToFloatBuffer(q, q_fp32, static_cast<size_t>(head_size) * sequence_length);
  float* k_fp32 = q_fp32 + head_size * sequence_length;
  MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seq_len);
Check failure: Code scanning / CodeQL
Multiplication result converted to larger type (High)
Copilot Autofix AI about 11 hours ago
To fix the problem, we need to ensure that the multiplication is performed using the larger type (`size_t`) so it cannot overflow.
- Cast one of the operands (`head_size` or `total_seq_len`) to `size_t` before the multiplication.
- This change should be made on line 244, where the multiplication occurs.
@@ -243,3 +243,3 @@
    float* k_fp32 = q_fp32 + head_size * sequence_length;
-   MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seq_len);
+   MlasConvertHalfToFloatBuffer(k, k_fp32, static_cast<size_t>(head_size) * total_seq_len);
      v, head_size, output_current, hidden_size,
      MLFloat16(1.0f).val, static_cast<uint16_t>(0) /*beta*/, nullptr);
} else {
  size_t bytes = head_size * total_seq_len * sizeof(float);
Check failure: Code scanning / CodeQL
Multiplication result converted to larger type (High)
Copilot Autofix AI about 11 hours ago
To fix the problem, we need to ensure that the multiplication is performed using the larger integer type to avoid overflow. The best fix is to cast `head_size` to `size_t` before the multiplication on line 450, so that the whole expression is evaluated in `size_t`.
@@ -449,3 +449,3 @@
  } else {
-   size_t bytes = head_size * total_seq_len * sizeof(float);
+   size_t bytes = static_cast<size_t>(head_size) * total_seq_len * sizeof(float);
    auto v_fp32 = allocator->Alloc(bytes);
  BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
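One subtlety shared by this alert and the one at line 236: `sizeof(float)` is itself a `size_t`, but it does not rescue the expression, because `*` associates left to right, so the `int * int` product is computed first and the `size_t` operand arrives too late. A sketch under the same assumption of 32-bit `int` dimensions:

```cpp
#include <cstddef>

int main() {
  // Hypothetical dimensions whose int product exceeds INT_MAX.
  int head_size = 128;
  int total_seq_len = 20'000'000;

  // Grouping is (head_size * total_seq_len) * sizeof(float): the inner
  // product is evaluated in int and can wrap before the widening multiply.
  //   size_t bad = head_size * total_seq_len * sizeof(float);

  // Casting the leftmost operand promotes every step of the chain.
  size_t good = static_cast<size_t>(head_size) * total_seq_len * sizeof(float);

  (void)good;  // good == 10240000000 (128 * 20'000'000 * 4)
  return 0;
}
```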
  float* v_fp32_ptr = static_cast<float*>(v_fp32);
  MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seq_len);
Check failure: Code scanning / CodeQL
Multiplication result converted to larger type (High)
Copilot Autofix AI about 11 hours ago
To fix the problem, we need to ensure that the multiplication is performed using `size_t`, which has a larger range than `int`, by casting one of the operands before multiplying. The best way to fix this without changing existing functionality is to cast `head_size` to `size_t` before the multiplication, on line 455 of onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h.
@@ -454,3 +454,3 @@
    float* v_fp32_ptr = static_cast<float*>(v_fp32);
-   MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seq_len);
+   MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, static_cast<size_t>(head_size) * total_seq_len);
Description
Add fp16 support to the CPU SparseAttention operator.
Motivation and Context
Generalize models so the same model can run on both CPU and GPU.
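For reviewers wanting the shape of the change: in the snippets above, the `else` branch widens the fp16 Q/K (and V) buffers into an fp32 scratch allocation and runs the math in fp32, presumably when a native fp16 kernel is unavailable. Below is a minimal sketch of that convert-then-compute layout; `convert_half_to_float` is a hypothetical scalar stand-in for the optimized `MlasConvertHalfToFloatBuffer`, and the dimension names mirror the snippets above:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical scalar half->float widening (sign/exponent/mantissa remap;
// subnormals flushed to zero for brevity). Stands in for MLAS's optimized
// MlasConvertHalfToFloatBuffer.
void convert_half_to_float(const uint16_t* src, float* dst, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    uint32_t h = src[i];
    uint32_t sign = (h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    uint32_t bits = (exp == 0)  ? sign                                 // zero / subnormal
                  : (exp == 31) ? (sign | 0x7F800000u | (mant << 13))  // inf / NaN
                                : (sign | ((exp + 112) << 23) | (mant << 13));
    std::memcpy(&dst[i], &bits, sizeof(float));
  }
}

// Fallback Q*K^T preparation on a CPU without fp16 GEMM: widen the fp16
// Q and K tiles into one fp32 scratch buffer, then compute in fp32.
void qk_fp32_fallback(const uint16_t* q, const uint16_t* k,
                      int head_size, int sequence_length, int total_seq_len) {
  // Note the size_t cast, mirroring the CodeQL fixes above.
  size_t elems =
      static_cast<size_t>(head_size) * (sequence_length + total_seq_len);
  std::vector<float> scratch(elems);

  float* q_fp32 = scratch.data();
  float* k_fp32 = q_fp32 + static_cast<size_t>(head_size) * sequence_length;
  convert_half_to_float(q, q_fp32, static_cast<size_t>(head_size) * sequence_length);
  convert_half_to_float(k, k_fp32, static_cast<size_t>(head_size) * total_seq_len);

  // ...an fp32 GEMM over q_fp32 / k_fp32 would follow here...
}

int main() { return 0; }  // compile-check stub; real callers pass tensor data
```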