Skip to content

Commit e8ce3c3

Browse files
manuelcandales authored and facebook-github-bot committed
metal lowbit kernels: executorch ops (#1322)
Summary: Refactors kernels/mps/src/OperationUntils.h, moving MetalShaderLibrary into its own header. Integrates MPS delegate functions into lowbit.h Registers out variants for the ATen ops Registers ET ops Differential Revision: D65957345
1 parent 5eb6339 commit e8ce3c3

14 files changed

+365
-76
lines changed

torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
*/
3838
3939
#ifdef USE_ATEN
40-
using namespace at::native::mps;
40+
using at::native::mps::MetalShaderLibrary;
4141
#else
42-
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
42+
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
4343
#endif
4444
4545
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#ifdef USE_EXECUTORCH
10+
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
11+
using executorch::backends::mps::delegate::MPSDevice;
12+
static id<MTLDevice> MTL_DEVICE = MPSDevice::getInstance()->device();
13+
#else
14+
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
15+
#endif
16+
17+
// Compiles a Metal shader library from the given MSL source text on `device`.
// In non-ExecuTorch builds a compilation failure throws via throw_exception;
// under USE_EXECUTORCH a nil library may be returned instead (error handling
// to be unified with ET — see TODO below).
static id<MTLLibrary> compileLibraryFromSource(
    id<MTLDevice> device,
    const std::string& source) {
  NSError* compileError = nil;
  MTLCompileOptions* compileOptions = [[MTLCompileOptions alloc] init];
  compileOptions.languageVersion = MTLLanguageVersion3_1;
  NSString* sourceString = [NSString stringWithUTF8String:source.c_str()];
  id<MTLLibrary> compiledLib = [device newLibraryWithSource:sourceString
                                                    options:compileOptions
                                                      error:&compileError];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
  if (compiledLib == nil) {
    throw_exception(
        "Failed to compile: " +
        std::string(compileError.description.UTF8String));
  }
#endif
  return compiledLib;
}
35+
36+
// Owns a Metal library compiled from the MSL source passed at construction
// and vends compute pipeline states for the kernel functions it contains.
// Non-copyable and non-movable: instances are meant to be long-lived
// (e.g. file-scope statics holding the quantized-kernel library).
class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(device, shaderSource);
  }
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Returns a compute pipeline state for the kernel named `fname`.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    id<MTLFunction> kernelFunc = load_func(fname);
    return get_compute_pipeline_state(kernelFunc);
  }

 private:
  std::string shaderSource;              // retained MSL source text
  id<MTLDevice> device = MTL_DEVICE;     // shared device (ET or local setup)
  id<MTLLibrary> lib = nil;              // compiled library, set in ctor

  // Looks up a kernel function by name in the compiled library.
  id<MTLFunction> load_func(const std::string& func_name) const {
    NSString* name = [NSString stringWithUTF8String:func_name.c_str()];
    id<MTLFunction> func = [lib newFunctionWithName:name];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (func == nil) {
      throw_exception("Can't get function:" + func_name);
    }
#endif
    return func;
  }

  // Wraps a kernel function in a compute pipeline state object.
  id<MTLComputePipelineState> get_compute_pipeline_state(
      id<MTLFunction> func) const {
    NSError* pipelineError = nil;
    auto cpl = [device newComputePipelineStateWithFunction:func
                                                     error:&pipelineError];
#ifndef USE_EXECUTORCH // TODO(mcandales): Unify with ET error handling
    if (cpl == nil) {
      throw_exception(
          "Failed to construct pipeline state: " +
          std::string(pipelineError.description.UTF8String));
    }
#endif
    return cpl;
  }
};

torchao/experimental/kernels/mps/src/OperationUtils.h

-65
Original file line numberDiff line numberDiff line change
@@ -40,63 +40,6 @@ inline id<MTLDevice> getMetalDevice() {
4040

4141
static id<MTLDevice> MTL_DEVICE = getMetalDevice();
4242

43-
static id<MTLLibrary> compileLibraryFromSource(
44-
id<MTLDevice> device,
45-
const std::string& source) {
46-
NSError* error = nil;
47-
MTLCompileOptions* options = [MTLCompileOptions new];
48-
[options setLanguageVersion:MTLLanguageVersion3_1];
49-
NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
50-
id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
51-
options:options
52-
error:&error];
53-
if (library == nil) {
54-
throw_exception(
55-
"Failed to compile: " + std::string(error.description.UTF8String));
56-
}
57-
return library;
58-
}
59-
60-
class MetalShaderLibrary {
61-
public:
62-
MetalShaderLibrary(const std::string& src) : shaderSource(src) {
63-
lib = compileLibraryFromSource(device, shaderSource);
64-
}
65-
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
66-
MetalShaderLibrary(MetalShaderLibrary&&) = delete;
67-
68-
id<MTLComputePipelineState> getPipelineStateForFunc(
69-
const std::string& fname) {
70-
return get_compute_pipeline_state(load_func(fname));
71-
}
72-
73-
private:
74-
std::string shaderSource;
75-
id<MTLDevice> device = MTL_DEVICE;
76-
id<MTLLibrary> lib = nil;
77-
78-
id<MTLFunction> load_func(const std::string& func_name) const {
79-
id<MTLFunction> func = [lib
80-
newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
81-
if (func == nil) {
82-
throw_exception("Can't get function:" + func_name);
83-
}
84-
return func;
85-
}
86-
87-
id<MTLComputePipelineState> get_compute_pipeline_state(
88-
id<MTLFunction> func) const {
89-
NSError* error = nil;
90-
auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
91-
if (cpl == nil) {
92-
throw_exception(
93-
"Failed to construct pipeline state: " +
94-
std::string(error.description.UTF8String));
95-
}
96-
return cpl;
97-
}
98-
};
99-
10043
class MPSStream {
10144
public:
10245
MPSStream() {
@@ -136,14 +79,6 @@ class MPSStream {
13679
id<MTLComputeCommandEncoder> _commandEncoder = nil;
13780
};
13881

139-
inline void finalize_block(MPSStream* mpsStream) {
140-
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
141-
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
142-
[encoder endEncoding];
143-
[cmdBuffer commit];
144-
[cmdBuffer waitUntilCompleted];
145-
}
146-
14782
inline MPSStream* getCurrentMPSStream() {
14883
return new MPSStream();
14984
}

torchao/experimental/kernels/mps/src/lowbit.h

+19-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
1111

1212
#include <torchao/experimental/kernels/mps/src/dispatch.h>
13-
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
13+
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
1414
#include <torchao/experimental/kernels/mps/src/packing.h>
1515

1616
#include <cassert>
@@ -20,9 +20,9 @@
2020
#ifdef USE_ATEN
2121
#include <ATen/native/mps/OperationUtils.h>
2222
using namespace at::native::mps;
23-
inline void finalize_block(MPSStream* mpsStream) {}
24-
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
25-
dispatch_sync_with_rethrow;
23+
#elif defined(USE_EXECUTORCH)
24+
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
25+
using namespace executorch::backends::mps::delegate;
2626
#else
2727
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
2828
#endif
@@ -103,7 +103,13 @@ inline void linear_lowbit_quant_weights_mps_impl(
103103
0};
104104

105105
MPSStream* mpsStream = getCurrentMPSStream();
106+
#if defined(USE_ATEN)
107+
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
108+
#elif defined(USE_EXECUTORCH)
109+
dispatch_sync(mpsStream->queue(), ^() {
110+
#else
106111
dispatch_block(mpsStream->queue(), ^() {
112+
#endif
107113
@autoreleasepool {
108114
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
109115
id<MTLComputePipelineState> cpl =
@@ -119,7 +125,15 @@ inline void linear_lowbit_quant_weights_mps_impl(
119125
length:sizeof(uint32_t) * sizes.size()
120126
atIndex:5];
121127
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
122-
finalize_block(mpsStream);
128+
#if defined(USE_EXECUTORCH)
129+
ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
130+
#elif !defined(USE_ATEN)
131+
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
132+
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
133+
[encoder endEncoding];
134+
[cmdBuffer commit];
135+
[cmdBuffer waitUntilCompleted];
136+
#endif
123137
}
124138
});
125139
}

torchao/experimental/ops/mps/CMakeLists.txt

+24
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,27 @@ install(
5858
EXPORT _targets
5959
DESTINATION lib
6060
)
61+
62+
if(TORCHAO_BUILD_EXECUTORCH_OPS)
63+
include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
64+
include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
65+
file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/executorch/*.mm")
66+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT ${_SRCS})
67+
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
68+
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")
69+
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
70+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
71+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
72+
73+
add_library(torchao_ops_mps_executorch STATIC)
74+
target_link_libraries(torchao_ops_mps_executorch PRIVATE
75+
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
76+
)
77+
install(
78+
TARGETS
79+
torchao_ops_mps_executorch
80+
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
81+
EXPORT _targets
82+
DESTINATION lib
83+
)
84+
endif()

torchao/experimental/ops/mps/aten/register.mm

+39-4
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,13 @@ void check_linear_mps_args(
7070
}
7171

7272
template <int nbit>
73-
Tensor linear_mps_kernel(
73+
Tensor linear_mps_kernel_out(
7474
const Tensor& A,
7575
const Tensor& B,
7676
int64_t group_size,
7777
const Tensor& S,
78-
const Tensor& Z) {
78+
const Tensor& Z,
79+
Tensor& C) {
7980
TORCH_CHECK(
8081
A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
8182
TORCH_CHECK(
@@ -84,15 +85,15 @@ Tensor linear_mps_kernel(
8485
S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
8586
TORCH_CHECK(
8687
Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
88+
TORCH_CHECK(
89+
C.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
8790

8891
check_linear_mps_args<nbit>(A, B, group_size, S, Z);
8992

9093
auto M = A.size(0);
9194
auto N = B.size(0);
9295
auto K = A.size(1);
9396

94-
auto C = at::empty({M, N}, A.options());
95-
9697
LowBitQuantWeights<nbit>::linear(
9798
getMTLBufferStorage(A),
9899
getMTLBufferStorage(B),
@@ -108,6 +109,19 @@ Tensor linear_mps_kernel(
108109
return C;
109110
}
110111

112+
template <int nbit>
113+
Tensor linear_mps_kernel(
114+
const Tensor& A,
115+
const Tensor& B,
116+
int64_t group_size,
117+
const Tensor& S,
118+
const Tensor& Z) {
119+
auto M = A.size(0);
120+
auto N = B.size(0);
121+
auto C = at::empty({M, N}, A.options());
122+
return linear_mps_kernel_out<nbit>(A, B, group_size, S, Z, C);
123+
}
124+
111125
template <int nbit>
112126
Tensor linear_mps_kernel_meta(
113127
const Tensor& A,
@@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
169183
"_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
170184
m.def(
171185
"_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
186+
m.def(
187+
"_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
188+
m.def(
189+
"_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
190+
m.def(
191+
"_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
192+
m.def(
193+
"_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
194+
m.def(
195+
"_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
196+
m.def(
197+
"_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
198+
m.def(
199+
"_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
172200
}
173201

174202
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
189217
m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>);
190218
m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>);
191219
m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
220+
m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>);
221+
m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>);
222+
m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>);
223+
m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>);
224+
m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>);
225+
m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>);
226+
m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>);
192227
}
193228

194229
TORCH_LIBRARY_IMPL(torchao, Meta, m) {

0 commit comments

Comments
 (0)