Skip to content

Commit a3753cc

Browse files
author
Xiaofei Han
committed
change variable
1 parent 81744e7 commit a3753cc

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

onnxruntime/core/providers/webgpu/math/gemm.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
6969
}
7070

7171
// calculateBeta
72-
if(has_C_input_) {
72+
if(need_handle_bias_) {
7373
const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
7474
shader.MainFunctionBody() << " value = value + A_value_t(uniforms.beta) * "
7575
<< C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
@@ -117,11 +117,11 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
117117
int64_t num_tiles_n = (N + TILE_SIZE - 1) / TILE_SIZE;
118118
int64_t dispatch_size = num_tiles_m * num_tiles_n;
119119

120-
GemmProgram program{transA_, transB_, alpha_, beta_, C != nullptr};
120+
GemmProgram program{transA_, transB_, alpha_, beta_, C&&beta_};
121121
program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
122122
{B, ProgramTensorMetadataDependency::Type}});
123123

124-
if(C != nullptr) {
124+
if(C && beta_) {
125125
program.AddInput({C, ProgramTensorMetadataDependency::Rank});
126126
}
127127

onnxruntime/core/providers/webgpu/math/gemm.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ namespace webgpu {
1212

1313
class GemmProgram final : public Program<GemmProgram> {
1414
public:
15-
GemmProgram(bool transA, bool transB, float alpha, float beta, bool has_C_input)
15+
GemmProgram(bool transA, bool transB, float alpha, float beta, bool need_handle_bias)
1616
: Program{"Gemm"},
1717
transA_{transA},
1818
transB_{transB},
1919
alpha_{alpha},
2020
beta_{beta},
21-
has_C_input_{has_C_input} {}
21+
need_handle_bias_{need_handle_bias} {}
2222

2323
Status GenerateShaderCode(ShaderHelper& sh) const override;
2424

@@ -35,7 +35,7 @@ class GemmProgram final : public Program<GemmProgram> {
3535
bool transB_;
3636
float alpha_;
3737
float beta_;
38-
bool has_C_input_;
38+
bool need_handle_bias_;
3939
};
4040

4141
class Gemm final : public WebGpuKernel {

0 commit comments

Comments
 (0)