// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/math/gemm.h"

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

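// Macros that register the Gemm kernel with the WebGPU execution provider,
// constraining T to the numeric types supported by WebGPU.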
#define WEBGPU_GEMM_VERSIONED_KERNEL(start, end) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
      Gemm, \
      kOnnxDomain, \
      start, \
      end, \
      kWebGpuExecutionProvider, \
      (*KernelDefBuilder::Create()) \
          .TypeConstraint("T", WebGpuSupportedNumberTypes()), \
      Gemm);

#define WEBGPU_GEMM_KERNEL(version) \
  ONNX_OPERATOR_KERNEL_EX( \
      Gemm, \
      kOnnxDomain, \
      version, \
      kWebGpuExecutionProvider, \
      (*KernelDefBuilder::Create()) \
          .TypeConstraint("T", WebGpuSupportedNumberTypes()), \
      Gemm);

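// Register Gemm for the versioned opset ranges 7-8, 9-10, 11-12 and for opset 13 onwards.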
WEBGPU_GEMM_VERSIONED_KERNEL(7, 8)
WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
WEBGPU_GEMM_KERNEL(13)

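// Generates a naive one-element-per-invocation shader: each invocation computes
// Y[m, n] = alpha * dot(A'[m, :], B'[:, n]) + beta * C, where A' and B' are A and B
// with the optional transA/transB transposes applied and C is broadcast to (M, N).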
Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

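  // Map the flat invocation index to an output coordinate (m, n). The dispatch is rounded
  // up to whole tiles, so guard against invocations that fall past the end of the output.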
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << " let m = global_idx / uniforms.N;\n"
                            << " let n = global_idx % uniforms.N;\n"
                            << " var value = A_value_t(0);\n"
                            << "\n"
                            << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";

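  // The A and B offsets depend on the transpose attributes: A is stored as (M, K), or as
  // (K, M) when transA is set; likewise B is stored as (K, N), or as (N, K) when transB is set.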
  if (transA_ && transB_) {
    shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
                              << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
  } else if (transA_ && !transB_) {
    shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
                              << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
  } else if (!transA_ && transB_) {
    shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
                              << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
  } else {
    shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
                              << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
  }
  shader.MainFunctionBody() << " }\n"
                            << "\n";
  // Apply the alpha scaling. The multiplication is skipped only when alpha is exactly 1;
  // in particular, alpha == 0 still zeroes out the matmul term, as the Gemm spec requires.
  if (alpha_ != 1.0f) {
    shader.MainFunctionBody() << " value = value * A_value_t(uniforms.alpha);\n";
  }

  // Add the bias term beta * C, broadcasting C to the output shape.
  if (need_handle_bias_) {
    const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
    shader.MainFunctionBody() << " value = value + A_value_t(uniforms.beta) * "
                              << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
  }

  shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}

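// Validates the inputs, infers (M, N, K) and the output shape, and dispatches the shader
// with one invocation per output element.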
Status Gemm::ComputeInternal(ComputeContext& context) const {
  const auto* A = context.Input<Tensor>(0);
  const auto* B = context.Input<Tensor>(1);
  const auto* C = context.Input<Tensor>(2);

  if (A == nullptr || B == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Gemm requires input tensors A and B.");
  }

  const auto& A_shape = A->Shape();
  const auto& B_shape = B->Shape();

  if (A_shape.NumDimensions() != 2 || B_shape.NumDimensions() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2-dimensional.");
  }

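  // After applying the optional transposes, A' is (M, K) and B' is (K, N), so Y is (M, N).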
  int64_t M = transA_ ? A_shape[1] : A_shape[0];
  int64_t K = transA_ ? A_shape[0] : A_shape[1];
  int64_t N = transB_ ? B_shape[0] : B_shape[1];

  if ((transA_ ? A_shape[0] : A_shape[1]) != (transB_ ? B_shape[1] : B_shape[0])) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
  }

  std::vector<int64_t> output_dims{M, N};
  auto* Y = context.Output(0, output_dims);
  int64_t output_size = Y->Shape().Size();

  if (output_size == 0) {
    return Status::OK();
  }

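  // Launch ceil(M/16) * ceil(N/16) workgroups of 16x16 invocations, enough to cover every
  // output element; the shader's bounds guard filters out the surplus invocations.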
  constexpr size_t TILE_SIZE = 16;
  int64_t num_tiles_m = (M + TILE_SIZE - 1) / TILE_SIZE;
  int64_t num_tiles_n = (N + TILE_SIZE - 1) / TILE_SIZE;
  int64_t dispatch_size = num_tiles_m * num_tiles_n;

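  // The final constructor argument drives need_handle_bias_: the bias term is emitted
  // only when C is provided and beta is non-zero.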
  GemmProgram program{transA_, transB_, alpha_, beta_, C && beta_};
  program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
                     {B, ProgramTensorMetadataDependency::Type}});

  if (C && beta_) {
    program.AddInput({C, ProgramTensorMetadataDependency::Rank});
  }

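  // Uniforms consumed by the shader: output_size, M, N, K, alpha and beta. Their order here
  // is assumed to match the uniform declarations for GemmProgram in gemm.h.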
  program.AddOutputs({Y})
      .SetDispatchGroupSize(dispatch_size)
      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1)
      .AddUniformVariables({
          {static_cast<uint32_t>(output_size)},  // output_size
          {static_cast<uint32_t>(M)},            // M
          {static_cast<uint32_t>(N)},            // N
          {static_cast<uint32_t>(K)},            // K
          {alpha_},                              // alpha
          {beta_}                                // beta
      });

  return context.RunProgram(program);
}

}  // namespace webgpu
}  // namespace onnxruntime