Skip to content

Commit 4351c41

Browse files
manuelcandales authored and facebook-github-bot committed
metal lowbit kernels: executorch ops (#1322)
Summary: Pull Request resolved: #1322. Refactors kernels/mps/src/OperationUtils.h, moving MetalShaderLibrary into its own header. Integrates MPS delegate functions into lowbit.h. Registers out variants for the ATen ops. Registers ET ops. Differential Revision: D65957345
1 parent 31234db commit 4351c41

14 files changed

+364
-147
lines changed

Diff for: torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
*/
3838
3939
#ifdef USE_ATEN
40-
using namespace at::native::mps;
40+
using at::native::mps::MetalShaderLibrary;
4141
#else
42-
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
42+
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
4343
#endif
4444
4545
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Diff for: torchao/experimental/kernels/mps/src/MetalShaderLibrary.h

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torchao/experimental/kernels/mps/src/common.h>
10+
11+
// Compiles a Metal shader library from source text at construction time and
// builds compute pipeline states for the kernels it contains.
//
// The library is compiled on the device returned by get_metal_device().
// Construction throws std::runtime_error when compilation fails. Instances are
// neither copyable nor movable.
class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(shaderSource);
  }
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Builds a compute pipeline state for the kernel named `fname`.
  //
  // Throws std::runtime_error when the function cannot be found in the
  // library or when pipeline creation fails. The returned object comes from a
  // `new...` Metal call, so without ARC the caller holds a +1 reference.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    // Resolve the device first so that nothing owned is pending if this throws.
    id<MTLDevice> device = get_metal_device();
    id<MTLFunction> func = loadFunc(fname);

    NSError* error = nil;
    id<MTLComputePipelineState> cpl =
        [device newComputePipelineStateWithFunction:func error:&error];
    // The function object is only needed to build the pipeline state; drop
    // our +1 reference before any early exit so it is not leaked per call.
    releaseOwned(func);
    if (cpl == nil) {
      throw std::runtime_error(
          "Failed to construct pipeline state: " + describeError(error));
    }
    return cpl;
  }

 private:
  std::string shaderSource;
  id<MTLLibrary> lib = nil;

  // Balances a +1 reference obtained from an alloc/new-style call.
  // Compiles to nothing under ARC, where the compiler manages ownership.
  static void releaseOwned(id object) {
#if !__has_feature(objc_arc)
    [object release];
#else
    (void)object;
#endif
  }

  // Converts an NSError into text for an exception message. Messaging a nil
  // error yields a NULL C string, which must not reach std::string's
  // constructor, so that case maps to a fixed placeholder.
  static std::string describeError(NSError* error) {
    const char* text = error.description.UTF8String;
    return text != nullptr ? std::string(text) : std::string("unknown error");
  }

  // Looks up `func_name` in the compiled library.
  // Throws std::runtime_error when the library has no such function.
  id<MTLFunction> loadFunc(const std::string& func_name) const {
    id<MTLFunction> func = [lib
        newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
    if (func == nil) {
      throw std::runtime_error("Can't get function:" + func_name);
    }
    return func;
  }

  // Compiles `source` with the Metal language version set to 3.1.
  // Throws std::runtime_error carrying the compiler diagnostics on failure.
  id<MTLLibrary> compileLibraryFromSource(
      const std::string& source) {
    // Resolve the device before allocating anything that needs releasing.
    id<MTLDevice> device = get_metal_device();

    NSError* error = nil;
    MTLCompileOptions* options = [MTLCompileOptions new];
    [options setLanguageVersion:MTLLanguageVersion3_1];
    NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
    id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
                                                  options:options
                                                    error:&error];
    // The options object is not used after the compile call returns.
    releaseOwned(options);
    if (library == nil) {
      throw std::runtime_error("Failed to compile: " + describeError(error));
    }
    return library;
  }
};

Diff for: torchao/experimental/kernels/mps/src/OperationUtils.h

+2-99
Original file line numberDiff line numberDiff line change
@@ -6,101 +6,12 @@
66

77
#pragma once
88

9-
#include <iostream>
10-
#include <stdexcept>
11-
12-
static void throw_exception(const std::string& str) {
13-
std::cerr << str << std::endl;
14-
throw std::runtime_error(str);
15-
}
16-
17-
inline void dispatch_block(
18-
[[maybe_unused]] id<MTLCommandQueue> queue,
19-
void (^block)()) {
20-
__block std::optional<std::exception_ptr> block_exception;
21-
try {
22-
block();
23-
} catch (...) {
24-
block_exception = std::current_exception();
25-
}
26-
if (block_exception) {
27-
std::rethrow_exception(*block_exception);
28-
}
29-
}
30-
31-
inline id<MTLDevice> getMetalDevice() {
32-
@autoreleasepool {
33-
NSArray* devices = [MTLCopyAllDevices() autorelease];
34-
if (devices.count == 0) {
35-
throw_exception("Metal is not supported");
36-
}
37-
return devices[0];
38-
}
39-
}
40-
41-
static id<MTLDevice> MTL_DEVICE = getMetalDevice();
42-
43-
static id<MTLLibrary> compileLibraryFromSource(
44-
id<MTLDevice> device,
45-
const std::string& source) {
46-
NSError* error = nil;
47-
MTLCompileOptions* options = [MTLCompileOptions new];
48-
[options setLanguageVersion:MTLLanguageVersion3_1];
49-
NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
50-
id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
51-
options:options
52-
error:&error];
53-
if (library == nil) {
54-
throw_exception(
55-
"Failed to compile: " + std::string(error.description.UTF8String));
56-
}
57-
return library;
58-
}
59-
60-
class MetalShaderLibrary {
61-
public:
62-
MetalShaderLibrary(const std::string& src) : shaderSource(src) {
63-
lib = compileLibraryFromSource(device, shaderSource);
64-
}
65-
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
66-
MetalShaderLibrary(MetalShaderLibrary&&) = delete;
67-
68-
id<MTLComputePipelineState> getPipelineStateForFunc(
69-
const std::string& fname) {
70-
return get_compute_pipeline_state(load_func(fname));
71-
}
72-
73-
private:
74-
std::string shaderSource;
75-
id<MTLDevice> device = MTL_DEVICE;
76-
id<MTLLibrary> lib = nil;
77-
78-
id<MTLFunction> load_func(const std::string& func_name) const {
79-
id<MTLFunction> func = [lib
80-
newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
81-
if (func == nil) {
82-
throw_exception("Can't get function:" + func_name);
83-
}
84-
return func;
85-
}
86-
87-
id<MTLComputePipelineState> get_compute_pipeline_state(
88-
id<MTLFunction> func) const {
89-
NSError* error = nil;
90-
auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
91-
if (cpl == nil) {
92-
throw_exception(
93-
"Failed to construct pipeline state: " +
94-
std::string(error.description.UTF8String));
95-
}
96-
return cpl;
97-
}
98-
};
9+
id<MTLDevice> getMetalDevice();
9910

10011
class MPSStream {
10112
public:
10213
MPSStream() {
103-
_commandQueue = [MTL_DEVICE newCommandQueue];
14+
_commandQueue = [getMetalDevice() newCommandQueue];
10415
}
10516

10617
~MPSStream() {
@@ -136,14 +47,6 @@ class MPSStream {
13647
id<MTLComputeCommandEncoder> _commandEncoder = nil;
13748
};
13849

139-
inline void finalize_block(MPSStream* mpsStream) {
140-
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
141-
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
142-
[encoder endEncoding];
143-
[cmdBuffer commit];
144-
[cmdBuffer waitUntilCompleted];
145-
}
146-
14750
// Returns a pointer to a newly heap-allocated MPSStream.
// NOTE(review): each call allocates a fresh stream and nothing visible here
// deletes it -- confirm who is expected to own the returned pointer.
inline MPSStream* getCurrentMPSStream() {
  MPSStream* freshStream = new MPSStream();
  return freshStream;
}
Diff for: torchao/experimental/kernels/mps/src/OperationUtils.mm

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include <Metal/Metal.h>
8+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
9+
#include <stdexcept>
10+
11+
// Queries the system for Metal devices and returns the first one with an
// extra retain, so the pointer stays valid after the autoreleased device
// array is destroyed.
// Throws std::runtime_error("Metal is not supported") when no device exists.
static id<MTLDevice> copyFirstMetalDevice() {
  @autoreleasepool {
    NSArray* devices = [MTLCopyAllDevices() autorelease];
    if (devices.count == 0) {
      throw std::runtime_error("Metal is not supported");
    }
    // This file uses manual reference counting (see the autorelease above),
    // so the element must be retained explicitly to outlive the array.
    return [devices[0] retain];
  }
}

// Returns the Metal device used by the standalone (non-ATen, non-ExecuTorch)
// build.
//
// The device lookup runs only during the first successful call; the result is
// cached in a function-local static for the rest of the process. If the lookup
// throws, the static stays uninitialised and the next call retries.
id<MTLDevice> getMetalDevice() {
  static id<MTLDevice> MTL_DEVICE = copyFirstMetalDevice();
  return MTL_DEVICE;
}

Diff for: torchao/experimental/kernels/mps/src/common.h

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <exception>

#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
#elif defined(USE_EXECUTORCH)
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
using namespace executorch::backends::mps::delegate;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif
18+
19+
// Runs `block` for the given stream using the mechanism of the active backend.
//
//  - USE_ATEN:       forwards to dispatch_sync_with_rethrow on the stream's
//                    queue.
//  - USE_EXECUTORCH: runs the block with dispatch_sync on the stream's queue.
//                    Any exception raised inside the block is captured there
//                    and rethrown on the calling thread after dispatch_sync
//                    returns, so it never unwinds through the dispatch call.
//  - otherwise:      calls the block directly; the stream is unused.
inline void dispatch_block(
    MPSStream* mpsStream,
    void (^block)()) {
#if defined(USE_ATEN)
  dispatch_sync_with_rethrow(mpsStream->queue(), block);
#elif defined(USE_EXECUTORCH)
  // Blocks submitted here can throw (e.g. pipeline-state construction throws
  // std::runtime_error), so carry the exception back to the caller.
  __block std::exception_ptr block_exception;
  dispatch_sync(mpsStream->queue(), ^() {
    try {
      block();
    } catch (...) {
      block_exception = std::current_exception();
    }
  });
  if (block_exception) {
    std::rethrow_exception(block_exception);
  }
#else
  (void)mpsStream;
  block();
#endif
}
31+
32+
// Blocks until work recorded on `mpsStream` has finished, where the active
// backend requires it.
//
//  - standalone build: ends encoding on the stream's command encoder, commits
//                      its command buffer and waits for completion.
//  - USE_ATEN:         does nothing.
//  - USE_EXECUTORCH:   checks via ET_CHECK that
//                      synchronize(SyncType::COMMIT_AND_WAIT) returns
//                      Error::Ok.
inline void optionally_wait_for_command_completion(MPSStream* mpsStream) {
#if !defined(USE_ATEN) && !defined(USE_EXECUTORCH)
  // Fetch both objects before mutating either, in the same order as before.
  id<MTLCommandEncoder> activeEncoder = mpsStream->commandEncoder();
  id<MTLCommandBuffer> pendingBuffer = mpsStream->commandBuffer();
  [activeEncoder endEncoding];
  [pendingBuffer commit];
  [pendingBuffer waitUntilCompleted];
#elif defined(USE_ATEN)
  // Intentionally empty for the ATen backend.
#else
  ET_CHECK(
      mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) ==
      executorch::runtime::Error::Ok);
#endif
}
44+
45+
// Returns the Metal device for the active backend: the standalone build asks
// getMetalDevice(), while the ATen and ExecuTorch builds ask the MPSDevice
// singleton brought into scope by the backend headers.
inline id<MTLDevice> get_metal_device() {
#if !defined(USE_ATEN) && !defined(USE_EXECUTORCH)
  return getMetalDevice();
#else
  return MPSDevice::getInstance()->device();
#endif
}

Diff for: torchao/experimental/kernels/mps/src/lowbit.h

+4-17
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,11 @@
99
#include <Metal/Metal.h>
1010
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
1111

12+
#include <torchao/experimental/kernels/mps/src/common.h>
1213
#include <torchao/experimental/kernels/mps/src/dispatch.h>
13-
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
14+
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
1415
#include <torchao/experimental/kernels/mps/src/packing.h>
1516

16-
#include <cassert>
17-
#include <fstream>
18-
#include <sstream>
19-
20-
#ifdef USE_ATEN
21-
#include <ATen/native/mps/OperationUtils.h>
22-
using namespace at::native::mps;
23-
inline void finalize_block(MPSStream* mpsStream) {}
24-
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
25-
dispatch_sync_with_rethrow;
26-
#else
27-
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
28-
#endif
29-
3017
namespace torchao::kernels::mps::lowbit {
3118
namespace {
3219

@@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
10390
0};
10491

10592
MPSStream* mpsStream = getCurrentMPSStream();
106-
dispatch_block(mpsStream->queue(), ^() {
93+
dispatch_block(mpsStream, ^() {
10794
@autoreleasepool {
10895
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
10996
id<MTLComputePipelineState> cpl =
@@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
119106
length:sizeof(uint32_t) * sizes.size()
120107
atIndex:5];
121108
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
122-
finalize_block(mpsStream);
109+
optionally_wait_for_command_completion(mpsStream);
123110
}
124111
});
125112
}

Diff for: torchao/experimental/kernels/mps/test/Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
all: test_lowbit
22

3-
test_lowbit: test_lowbit.mm
4-
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation
3+
test_lowbit: test_lowbit.mm ../src/OperationUtils.mm
4+
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation
55

66
run: test_lowbit
77
./test_lowbit

Diff for: torchao/experimental/kernels/mps/test/test_lowbit.mm

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
id<MTLBuffer> rc = [device newBufferWithLength:length
3232
options:MTLResourceStorageModeShared];
3333
if (rc == nil) {
34-
throw_exception(
34+
throw std::runtime_error(
3535
"Can't allocate " + std::to_string(length) + " bytes on GPU");
3636
}
3737
return rc;
@@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu(
8080
: M(m), K(k), N(n), qGroupSize(group_size) {}
8181

8282
void init() {
83-
allocBuffers(MTL_DEVICE);
83+
allocBuffers(getMetalDevice());
8484

8585
T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
8686
uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);

Diff for: torchao/experimental/ops/mps/CMakeLists.txt

+26-6
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ endif()
2626
find_package(Torch REQUIRED)
2727

2828
# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
29+
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
# add_custom_command(DEPENDS ...) takes file names literally and does not
# expand wildcards, so resolve the shader list up front. CONFIGURE_DEPENDS
# re-runs the glob at build time so added or removed shaders are picked up.
file(GLOB METAL_SHADER_FILES CONFIGURE_DEPENDS ${METAL_SHADERS_DIR}/*.metal)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
  OUTPUT ${GENERATED_METAL_SHADER_LIB}
  COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
  # Regenerate the header whenever a shader or the generator script changes.
  DEPENDS ${METAL_SHADER_FILES} ${GEN_SCRIPT}
  COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
3538
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
@@ -41,7 +44,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
4144

4245
include_directories(${TORCHAO_INCLUDE_DIRS})
4346
include_directories(${CMAKE_INSTALL_PREFIX}/include)
44-
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
47+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
4548
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)
4649

4750
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
@@ -53,8 +56,25 @@ find_library(METAL_LIB Metal)
5356
find_library(FOUNDATION_LIB Foundation)
5457
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
5558

56-
install(
57-
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
58-
EXPORT _targets
59-
DESTINATION lib
59+
add_library(torchao_ops_mps_aten SHARED)
60+
target_link_libraries(torchao_ops_mps_aten PRIVATE
61+
torchao_ops_mps_linear_fp_act_xbit_weight_aten
6062
)
63+
install(TARGETS torchao_ops_mps_aten DESTINATION lib)
64+
65+
if(TORCHAO_BUILD_EXECUTORCH_OPS)
66+
include_directories(${CMAKE_INSTALL_PREFIX}/../..)
67+
include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
68+
include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
69+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm)
70+
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
71+
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
72+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate)
73+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
74+
75+
add_library(torchao_ops_mps_executorch STATIC)
76+
target_link_libraries(torchao_ops_mps_executorch PRIVATE
77+
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
78+
)
79+
install(TARGETS torchao_ops_mps_executorch DESTINATION lib)
80+
endif()

0 commit comments

Comments
 (0)