Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][webgpu] Add GEMM implementation #24023

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
{ name: 'alpha', type: 'f32' },
{ name: 'beta', type: 'f32' },
];
return `
const result = `
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart()}
Expand All @@ -134,6 +134,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
})()}
output[global_idx] = value;
}`;
console.log("xiaofeihan:", result);
return result;
};

const getShaderSourceShared = (shaderHelper: ShaderHelper) => {
Expand Down
146 changes: 146 additions & 0 deletions onnxruntime/core/providers/webgpu/math/gemm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
// Licensed under the MIT License.

#include "core/providers/webgpu/math/gemm.h"

#include <string>
#include <vector>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Registers a Gemm kernel with the WebGPU execution provider for the opset
// range [start, end], constrained to WebGPU-supported number types.
#define WEBGPU_GEMM_VERSIONED_KERNEL(start, end) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
start, \
end, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", WebGpuSupportedNumberTypes()), \
Gemm);

// Same as above, but for a single (open-ended) opset version.
#define WEBGPU_GEMM_KERNEL(version) \
ONNX_OPERATOR_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
version, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", WebGpuSupportedNumberTypes()), \
Gemm);

// Gemm kernel registrations, one per ONNX opset range in which the
// operator's definition changed.
WEBGPU_GEMM_VERSIONED_KERNEL(7, 8)
WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
WEBGPU_GEMM_KERNEL(13)

// Generates a naive GEMM shader: one invocation per output element, computing
// output[m, n] = alpha * sum_k A'[m, k] * B'[k, n] (+ beta * C[m, n]),
// where A' / B' are the optionally-transposed views of A and B.
Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "  let m = global_idx / uniforms.N;\n"
                            << "  let n = global_idx % uniforms.N;\n"
                            << "  var value = A_value_t(0);\n"
                            << "\n"
                            << "  for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";

  // Flattened-offset addressing for the optionally-transposed views:
  // A is stored M x K (K x M when transA); B is stored K x N (N x K when transB).
  const std::string a_offset = transA_ ? "k * uniforms.M + m" : "m * uniforms.K + k";
  const std::string b_offset = transB_ ? "n * uniforms.K + k" : "k * uniforms.N + n";
  shader.MainFunctionBody() << "    value = value + " << A.GetByOffset(a_offset)
                            << " * " << B.GetByOffset(b_offset) << ";\n";

  shader.MainFunctionBody() << "  }\n"
                            << "\n";

  // Scale by alpha. BUG FIX: the previous `if (alpha_)` check skipped the
  // scaling entirely when alpha == 0, yielding A*B instead of 0. The only
  // value for which the scale is a no-op is alpha == 1.
  if (alpha_ != 1.0f) {
    shader.MainFunctionBody() << "  value = value * A_value_t(uniforms.alpha);\n";
  }

  // Optional C input: value += beta * C (with unidirectional broadcasting).
  // C must be declared whenever it is bound to the program -- even if beta == 0
  // -- so the shader's bind group layout matches the dispatched resources.
  if (has_C_input_) {
    const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
    if (beta_ != 0.0f) {  // adding beta * C is a no-op when beta == 0
      shader.MainFunctionBody() << "  value = value + A_value_t(uniforms.beta) * "
                                << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
    }
  }

  shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}

// Validates inputs, computes the effective GEMM dimensions, and dispatches the
// shader program. Returns INVALID_ARGUMENT on missing/mis-shaped inputs.
Status Gemm::ComputeInternal(ComputeContext& context) const {
  const auto* A = context.Input<Tensor>(0);
  const auto* B = context.Input<Tensor>(1);
  const auto* C = context.Input<Tensor>(2);  // optional; may be nullptr

  if (A == nullptr || B == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Gemm requires input tensors A and B.");
  }

  const auto& A_shape = A->Shape();
  const auto& B_shape = B->Shape();

  if (A_shape.NumDimensions() != 2 || B_shape.NumDimensions() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
  }

  // Effective dimensions after the optional transposes:
  // op(A) is M x K, op(B) is K x N, output is M x N.
  const int64_t M = transA_ ? A_shape[1] : A_shape[0];
  const int64_t K = transA_ ? A_shape[0] : A_shape[1];
  const int64_t N = transB_ ? B_shape[0] : B_shape[1];

  // Reuse the already-computed K instead of re-deriving A's inner dimension.
  if (K != (transB_ ? B_shape[1] : B_shape[0])) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
  }

  std::vector<int64_t> output_dims{M, N};
  auto* Y = context.Output(0, output_dims);
  const int64_t output_size = Y->Shape().Size();

  // Empty output: nothing to dispatch.
  if (output_size == 0) {
    return Status::OK();
  }

  // One 16x16 workgroup per output tile; the shader guards out-of-bounds
  // invocations via uniforms.output_size. uint32_t avoids signed/unsigned
  // mixing in the tile arithmetic below and matches SetWorkgroupSize.
  constexpr uint32_t TILE_SIZE = 16;
  const int64_t num_tiles_m = (M + TILE_SIZE - 1) / TILE_SIZE;
  const int64_t num_tiles_n = (N + TILE_SIZE - 1) / TILE_SIZE;
  const int64_t dispatch_size = num_tiles_m * num_tiles_n;

  GemmProgram program{transA_, transB_, alpha_, beta_, C != nullptr};
  program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
                     {B, ProgramTensorMetadataDependency::Type}});

  if (C != nullptr) {
    // Rank dependency: C's shape participates in broadcasting in the shader.
    program.AddInput({C, ProgramTensorMetadataDependency::Rank});
  }

  // Uniform order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES in gemm.h.
  program.AddOutputs({Y})
      .SetDispatchGroupSize(static_cast<uint32_t>(dispatch_size))
      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1)
      .AddUniformVariables({
          {static_cast<uint32_t>(output_size)},  // output_size
          {static_cast<uint32_t>(M)},            // M
          {static_cast<uint32_t>(N)},            // N
          {static_cast<uint32_t>(K)},            // K
          {alpha_},                              // alpha
          {beta_}                                // beta
      });

  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
66 changes: 66 additions & 0 deletions onnxruntime/core/providers/webgpu/math/gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Shader program for the WebGPU Gemm kernel. Generates a naive per-element
// GEMM shader; configuration (transpose flags, alpha/beta, presence of the
// optional C input) is captured at construction and fixed for the program.
class GemmProgram final : public Program<GemmProgram> {
 public:
  // transA/transB: transpose A / B before multiplying.
  // alpha/beta:    scale factors applied to A*B and C respectively.
  // has_C_input:   whether the optional C tensor is bound as a program input.
  GemmProgram(bool transA, bool transB, float alpha, float beta, bool has_C_input)
      : Program{"Gemm"},
        transA_{transA},
        transB_{transB},
        alpha_{alpha},
        beta_{beta},
        has_C_input_{has_C_input} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms consumed by the generated shader. Order here must match the
  // order of AddUniformVariables in Gemm::ComputeInternal.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"output_size", ProgramUniformVariableDataType::Uint32},
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"alpha", ProgramUniformVariableDataType::Float32},
      {"beta", ProgramUniformVariableDataType::Float32});

 private:
  bool transA_;
  bool transB_;
  float alpha_;
  float beta_;
  bool has_C_input_;
};

class Gemm final : public WebGpuKernel {
public:
Gemm(const OpKernelInfo& info) : WebGpuKernel(info) {
int64_t transA_temp;
info.GetAttrOrDefault("transA", &transA_temp, static_cast<int64_t>(0));
transA_ = transA_temp != 0;

int64_t transB_temp;
info.GetAttrOrDefault("transB", &transB_temp, static_cast<int64_t>(0));
transB_ = transB_temp != 0;

info.GetAttrOrDefault("alpha", &alpha_, 1.0f);
info.GetAttrOrDefault("beta", &beta_, 1.0f);
}

Status ComputeInternal(ComputeContext& context) const override;

private:
bool transA_;
bool transB_;
float alpha_;
float beta_;
};

} // namespace webgpu
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul)>,

Expand Down
12 changes: 11 additions & 1 deletion samples/nodejs/01_basic-usage/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ async function main() {
// the model in this example contains a single MatMul node
// it has 2 inputs: 'a'(float32, 3x4) and 'b'(float32, 4x3)
// it has 1 output: 'c'(float32, 3x3)
const session = await ort.InferenceSession.create('./model.onnx');
// const session = await ort.InferenceSession.create('./model.onnx');

// WebGPU
const session = await ort.InferenceSession.create('./model.onnx', {
executionProviders: [
{
name: 'webgpu',
validationMode: 'wgpuOnly',
storageBufferCacheMode: 'bucket'
}
]});

// prepare inputs. a tensor need its corresponding TypedArray as data
const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
Expand Down
2 changes: 1 addition & 1 deletion samples/nodejs/01_basic-usage/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
"test": "echo \"Error: no test specified\" && exit 1"
},
"dependencies": {
"onnxruntime-node": "^1.20.1"
"onnxruntime-node": "file:../../../js/node"
}
}
Loading