Skip to content

metal lowbit kernels: executorch ops #1322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
*/
#ifdef USE_ATEN
using namespace at::native::mps;
using at::native::mps::MetalShaderLibrary;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
#endif
static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
Expand Down
64 changes: 64 additions & 0 deletions torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torchao/experimental/kernels/mps/src/common.h>

// Minimal stand-in for ATen's at::native::mps::MetalShaderLibrary, used when
// the kernels are built standalone (neither USE_ATEN nor USE_EXECUTORCH).
// Compiles a Metal source string at construction time and vends compute
// pipeline states by kernel function name.
class MetalShaderLibrary {
 public:
  // Compiles `src` eagerly; throws std::runtime_error if compilation fails.
  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
    lib = compileLibraryFromSource(shaderSource);
  }
  // Non-copyable / non-movable: owns a compiled MTLLibrary handle.
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  MetalShaderLibrary(MetalShaderLibrary&&) = delete;

  // Builds a compute pipeline state for the named kernel function.
  // Throws std::runtime_error when the function is missing or the pipeline
  // cannot be constructed.
  id<MTLComputePipelineState> getPipelineStateForFunc(
      const std::string& fname) {
    id<MTLFunction> fn = loadFunc(fname);

    // NOTE(review): when built without ARC, `fn` (from a new* method) is
    // never released here — confirm the build's memory-management mode.
    NSError* err = nil;
    id<MTLComputePipelineState> pso =
        [get_metal_device() newComputePipelineStateWithFunction:fn error:&err];
    if (pso == nil) {
      throw std::runtime_error(
          "Failed to construct pipeline state: " +
          std::string(err.description.UTF8String));
    }
    return pso;
  }

 private:
  // Looks up a kernel function in the compiled library; throws if absent.
  id<MTLFunction> loadFunc(const std::string& func_name) const {
    NSString* name = [NSString stringWithUTF8String:func_name.c_str()];
    id<MTLFunction> fn = [lib newFunctionWithName:name];
    if (fn == nil) {
      throw std::runtime_error("Can't get function:" + func_name);
    }
    return fn;
  }

  // Compiles `source` as Metal 3.1; throws with the compiler diagnostic on
  // failure. NOTE(review): the MTLCompileOptions object is not released —
  // harmless under ARC, a small one-time leak under MRC; confirm build mode.
  id<MTLLibrary> compileLibraryFromSource(const std::string& source) {
    NSError* err = nil;
    MTLCompileOptions* opts = [MTLCompileOptions new];
    [opts setLanguageVersion:MTLLanguageVersion3_1];
    NSString* src = [NSString stringWithUTF8String:source.c_str()];
    id<MTLLibrary> library = [get_metal_device() newLibraryWithSource:src
                                                              options:opts
                                                                error:&err];
    if (library == nil) {
      throw std::runtime_error(
          "Failed to compile: " + std::string(err.description.UTF8String));
    }
    return library;
  }

  std::string shaderSource;
  id<MTLLibrary> lib = nil;
};
101 changes: 2 additions & 99 deletions torchao/experimental/kernels/mps/src/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,12 @@

#pragma once

#include <iostream>
#include <stdexcept>

static void throw_exception(const std::string& str) {
std::cerr << str << std::endl;
throw std::runtime_error(str);
}

inline void dispatch_block(
[[maybe_unused]] id<MTLCommandQueue> queue,
void (^block)()) {
__block std::optional<std::exception_ptr> block_exception;
try {
block();
} catch (...) {
block_exception = std::current_exception();
}
if (block_exception) {
std::rethrow_exception(*block_exception);
}
}

inline id<MTLDevice> getMetalDevice() {
@autoreleasepool {
NSArray* devices = [MTLCopyAllDevices() autorelease];
if (devices.count == 0) {
throw_exception("Metal is not supported");
}
return devices[0];
}
}

static id<MTLDevice> MTL_DEVICE = getMetalDevice();

static id<MTLLibrary> compileLibraryFromSource(
id<MTLDevice> device,
const std::string& source) {
NSError* error = nil;
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:MTLLanguageVersion3_1];
NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
options:options
error:&error];
if (library == nil) {
throw_exception(
"Failed to compile: " + std::string(error.description.UTF8String));
}
return library;
}

class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src) {
lib = compileLibraryFromSource(device, shaderSource);
}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
MetalShaderLibrary(MetalShaderLibrary&&) = delete;

id<MTLComputePipelineState> getPipelineStateForFunc(
const std::string& fname) {
return get_compute_pipeline_state(load_func(fname));
}

private:
std::string shaderSource;
id<MTLDevice> device = MTL_DEVICE;
id<MTLLibrary> lib = nil;

id<MTLFunction> load_func(const std::string& func_name) const {
id<MTLFunction> func = [lib
newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
if (func == nil) {
throw_exception("Can't get function:" + func_name);
}
return func;
}

id<MTLComputePipelineState> get_compute_pipeline_state(
id<MTLFunction> func) const {
NSError* error = nil;
auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
if (cpl == nil) {
throw_exception(
"Failed to construct pipeline state: " +
std::string(error.description.UTF8String));
}
return cpl;
}
};
id<MTLDevice> getMetalDevice();

class MPSStream {
public:
MPSStream() {
_commandQueue = [MTL_DEVICE newCommandQueue];
_commandQueue = [getMetalDevice() newCommandQueue];
}

~MPSStream() {
Expand Down Expand Up @@ -136,14 +47,6 @@ class MPSStream {
id<MTLComputeCommandEncoder> _commandEncoder = nil;
};

inline void finalize_block(MPSStream* mpsStream) {
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
}

// Standalone (non-ATen/ExecuTorch) stand-in for the framework-provided
// getCurrentMPSStream().
// NOTE(review): each call heap-allocates a new MPSStream that no visible
// caller deletes — confirm callers take ownership, or consider returning a
// process-lifetime singleton instead.
inline MPSStream* getCurrentMPSStream() {
return new MPSStream();
}
20 changes: 20 additions & 0 deletions torchao/experimental/kernels/mps/src/OperationUtils.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <stdexcept>

// Returns the Metal device used by the torchao MPS kernels (the first device
// reported by the system), caching it for the lifetime of the process.
// Throws std::runtime_error if no Metal device is available.
id<MTLDevice> getMetalDevice() {
  @autoreleasepool {
    NSArray* devices = [MTLCopyAllDevices() autorelease];
    if (devices.count == 0) {
      throw std::runtime_error("Metal is not supported");
    }
    // This file builds without ARC (see the explicit `autorelease` above), so
    // the device must be retained: the autoreleased `devices` array — the only
    // thing keeping its elements alive — is released when this pool drains,
    // which would leave the cached static pointing at a potentially
    // deallocated object. Intentionally never released (process-lifetime
    // singleton; C++ magic statics make the one-time init thread-safe).
    static id<MTLDevice> MTL_DEVICE = [devices[0] retain];
    return MTL_DEVICE;
  }
}
51 changes: 51 additions & 0 deletions torchao/experimental/kernels/mps/src/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
#elif defined(USE_EXECUTORCH)
#include <executorch/backends/apple/mps/runtime/MPSStream.h>
using namespace executorch::backends::mps::delegate;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

// Executes `block` using the dispatch mechanism appropriate for the build:
//  - USE_ATEN:       synchronously on the stream's queue, rethrowing any
//                    exception raised inside the block (ATen helper).
//  - USE_EXECUTORCH: synchronously on the stream's queue (plain GCD).
//  - standalone:     inline on the calling thread; `mpsStream` is unused.
inline void dispatch_block(
MPSStream* mpsStream,
void (^block)()) {
#if defined(USE_ATEN)
dispatch_sync_with_rethrow(mpsStream->queue(), block);
#elif defined(USE_EXECUTORCH)
dispatch_sync(mpsStream->queue(), block);
#else
(void)mpsStream;
block();
#endif
}

// Finishes the current command stream where the host build requires it:
//  - USE_ATEN:       no-op — ATen manages commit/synchronization itself.
//  - USE_EXECUTORCH: commits via the delegate stream and blocks until done.
//  - standalone:     ends encoding, commits the command buffer, and waits
//                    for GPU completion.
inline void optionally_wait_for_command_completion(MPSStream* mpsStream) {
#if defined(USE_ATEN)
#elif defined(USE_EXECUTORCH)
ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
#else
id<MTLCommandEncoder> encoder = mpsStream->commandEncoder();
id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
[encoder endEncoding];
[cmdBuffer commit];
[cmdBuffer waitUntilCompleted];
#endif
}

// Returns the Metal device for the current build flavor: the framework's
// shared MPSDevice singleton under ATen/ExecuTorch, or the standalone
// getMetalDevice() helper otherwise.
inline id<MTLDevice> get_metal_device() {
#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
return MPSDevice::getInstance()->device();
#else
return getMetalDevice();
#endif
}
21 changes: 4 additions & 17 deletions torchao/experimental/kernels/mps/src/lowbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,11 @@
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <torchao/experimental/kernels/mps/src/common.h>
#include <torchao/experimental/kernels/mps/src/dispatch.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
#include <torchao/experimental/kernels/mps/src/packing.h>

#include <cassert>
#include <fstream>
#include <sstream>

#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
inline void finalize_block(MPSStream* mpsStream) {}
void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
dispatch_sync_with_rethrow;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

namespace torchao::kernels::mps::lowbit {
namespace {

Expand Down Expand Up @@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
0};

MPSStream* mpsStream = getCurrentMPSStream();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for PT integration we are using a separate implementation of this? And for executorch we use one from delegate?

Do we not have the same from aten?

Also what is the interaction with device and the stream? It seems that you are creating your own device when compiling for PyTorch, but then use getCurrentMPSStream from aten?

Same question for ET. @malfet thoughts on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for PT integration we are using a separate implementation of this? And for executorch we use one from delegate?

For PT we are using the getCurrentMPSStream implementation from the ATen MPS code, and for ET we are using the getCurrentMPSStream implementation from the ET MPS delegate code.

Also what is the interaction with device and the stream? It seems that you are creating your own device when compiling for PyTorch, but then use getCurrentMPSStream from aten?

I am not creating a device in any of those cases (PT or ET). In those cases I basically just retrieve the device via MPSDevice::getInstance()->device(). It is the same in that respect PT vs ET.

dispatch_block(mpsStream->queue(), ^() {
dispatch_block(mpsStream, ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
id<MTLComputePipelineState> cpl =
Expand All @@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
length:sizeof(uint32_t) * sizes.size()
atIndex:5];
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
finalize_block(mpsStream);
optionally_wait_for_command_completion(mpsStream);
}
});
}
Expand Down
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/mps/test/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
all: test_lowbit

test_lowbit: test_lowbit.mm
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation
test_lowbit: test_lowbit.mm ../src/OperationUtils.mm
clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation

run: test_lowbit
./test_lowbit
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/mps/test/test_lowbit.mm
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
id<MTLBuffer> rc = [device newBufferWithLength:length
options:MTLResourceStorageModeShared];
if (rc == nil) {
throw_exception(
throw std::runtime_error(
"Can't allocate " + std::to_string(length) + " bytes on GPU");
}
return rc;
Expand Down Expand Up @@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu(
: M(m), K(k), N(n), qGroupSize(group_size) {}

void init() {
allocBuffers(MTL_DEVICE);
allocBuffers(getMetalDevice());

T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
Expand Down
32 changes: 26 additions & 6 deletions torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ endif()
find_package(Torch REQUIRED)

# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
DEPENDS ${METAL_SHADERS_DIR}/*.metal ${GEN_SCRIPT}
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
Expand All @@ -41,7 +44,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

include_directories(${TORCHAO_INCLUDE_DIRS})
include_directories(${CMAKE_INSTALL_PREFIX}/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)

target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
Expand All @@ -53,8 +56,25 @@ find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

install(
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
EXPORT _targets
DESTINATION lib
add_library(torchao_ops_mps_aten SHARED)
target_link_libraries(torchao_ops_mps_aten PRIVATE
torchao_ops_mps_linear_fp_act_xbit_weight_aten
)
install(TARGETS torchao_ops_mps_aten DESTINATION lib)

if(TORCHAO_BUILD_EXECUTORCH_OPS)
include_directories(${CMAKE_INSTALL_PREFIX}/../..)
include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm)
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

add_library(torchao_ops_mps_executorch STATIC)
target_link_libraries(torchao_ops_mps_executorch PRIVATE
torchao_ops_mps_linear_fp_act_xbit_weight_executorch
)
install(TARGETS torchao_ops_mps_executorch DESTINATION lib)
endif()
Loading
Loading