diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
index eea7e42666..7764c0871f 100644
--- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
+++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
@@ -37,9 +37,9 @@
 */
 #ifdef USE_ATEN
-using namespace at::native::mps;
+using at::native::mps::MetalShaderLibrary;
 #else
-#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
+#include <torchao/experimental/kernels/mps/src/MetalShaderLibrary.h>
 #endif
 
 static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
diff --git a/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
new file mode 100644
index 0000000000..3aca35e699
--- /dev/null
+++ b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h
@@ -0,0 +1,64 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <torchao/experimental/kernels/mps/src/common.h>
+
+class MetalShaderLibrary {
+ public:
+  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
+    lib = compileLibraryFromSource(shaderSource);
+  }
+  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
+  MetalShaderLibrary(MetalShaderLibrary&&) = delete;
+
+  id<MTLComputePipelineState> getPipelineStateForFunc(
+      const std::string& fname) {
+    id<MTLFunction> func = loadFunc(fname);
+
+    NSError* error = nil;
+    id<MTLDevice> device = get_metal_device();
+    auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
+    if (cpl == nil) {
+      throw std::runtime_error(
+          "Failed to construct pipeline state: " +
+          std::string(error.description.UTF8String));
+    }
+    return cpl;
+  }
+
+ private:
+  std::string shaderSource;
+  id<MTLLibrary> lib = nil;
+
+  id<MTLFunction> loadFunc(const std::string& func_name) const {
+    id<MTLFunction> func = [lib
+        newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
+    if (func == nil) {
+      throw std::runtime_error("Can't get function:" + func_name);
+    }
+    return func;
+  }
+
+  id<MTLLibrary> compileLibraryFromSource(
+      const std::string& source) {
+    NSError* error = nil;
+    MTLCompileOptions* options = [MTLCompileOptions new];
+    [options setLanguageVersion:MTLLanguageVersion3_1];
+    NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
+    id<MTLDevice> device = get_metal_device();
+    id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
+                                                  options:options
+                                                    error:&error];
+    if (library == nil) {
+      throw std::runtime_error(
+          "Failed to compile: " + std::string(error.description.UTF8String));
+    }
+    return library;
+  }
+};
diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h
index 7cb902f23f..5a41b264af 100644
--- a/torchao/experimental/kernels/mps/src/OperationUtils.h
+++ b/torchao/experimental/kernels/mps/src/OperationUtils.h
@@ -6,101 +6,12 @@
 
 #pragma once
 
-#include <iostream>
-#include <optional>
-
-static void throw_exception(const std::string& str) {
-  std::cerr << str << std::endl;
-  throw std::runtime_error(str);
-}
-
-inline void dispatch_block(
-    [[maybe_unused]] id<MTLCommandQueue> queue,
-    void (^block)()) {
-  __block std::optional<std::exception_ptr> block_exception;
-  try {
-    block();
-  } catch (...) {
-    block_exception = std::current_exception();
-  }
-  if (block_exception) {
-    std::rethrow_exception(*block_exception);
-  }
-}
-
-inline id<MTLDevice> getMetalDevice() {
-  @autoreleasepool {
-    NSArray* devices = [MTLCopyAllDevices() autorelease];
-    if (devices.count == 0) {
-      throw_exception("Metal is not supported");
-    }
-    return devices[0];
-  }
-}
-
-static id<MTLDevice> MTL_DEVICE = getMetalDevice();
-
-static id<MTLLibrary> compileLibraryFromSource(
-    id<MTLDevice> device,
-    const std::string& source) {
-  NSError* error = nil;
-  MTLCompileOptions* options = [MTLCompileOptions new];
-  [options setLanguageVersion:MTLLanguageVersion3_1];
-  NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()];
-  id<MTLLibrary> library = [device newLibraryWithSource:kernel_source
-                                                options:options
-                                                  error:&error];
-  if (library == nil) {
-    throw_exception(
-        "Failed to compile: " + std::string(error.description.UTF8String));
-  }
-  return library;
-}
-
-class MetalShaderLibrary {
- public:
-  MetalShaderLibrary(const std::string& src) : shaderSource(src) {
-    lib = compileLibraryFromSource(device, shaderSource);
-  }
-  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
-  MetalShaderLibrary(MetalShaderLibrary&&) = delete;
-
-  id<MTLComputePipelineState> getPipelineStateForFunc(
-      const std::string& fname) {
-    return get_compute_pipeline_state(load_func(fname));
-  }
-
- private:
-  std::string shaderSource;
-  id<MTLDevice> device = MTL_DEVICE;
-  id<MTLLibrary> lib = nil;
-
-  id<MTLFunction> load_func(const std::string& func_name) const {
-    id<MTLFunction> func = [lib
-        newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
-    if (func == nil) {
-      throw_exception("Can't get function:" + func_name);
-    }
-    return func;
-  }
-
-  id<MTLComputePipelineState> get_compute_pipeline_state(
-      id<MTLFunction> func) const {
-    NSError* error = nil;
-    auto cpl = [device newComputePipelineStateWithFunction:func error:&error];
-    if (cpl == nil) {
-      throw_exception(
-          "Failed to construct pipeline state: " +
-          std::string(error.description.UTF8String));
-    }
-    return cpl;
-  }
-};
+id<MTLDevice> getMetalDevice();
 
 class MPSStream {
  public:
   MPSStream() {
-    _commandQueue = [MTL_DEVICE newCommandQueue];
+    _commandQueue = [getMetalDevice() newCommandQueue];
   }
 
   ~MPSStream() {
@@ -136,14 +47,6 @@ class MPSStream {
   id<MTLComputeCommandEncoder> _commandEncoder = nil;
 };
 
-inline void finalize_block(MPSStream* mpsStream) {
-  id<MTLComputeCommandEncoder> encoder = mpsStream->commandEncoder();
-  id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
-  [encoder endEncoding];
-  [cmdBuffer commit];
-  [cmdBuffer waitUntilCompleted];
-}
-
 inline MPSStream* getCurrentMPSStream() {
   return new MPSStream();
 }
diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.mm b/torchao/experimental/kernels/mps/src/OperationUtils.mm
new file mode 100644
index 0000000000..795c93225a
--- /dev/null
+++ b/torchao/experimental/kernels/mps/src/OperationUtils.mm
@@ -0,0 +1,20 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <Metal/Metal.h>
+#include <stdexcept>
+#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
+
+id<MTLDevice> getMetalDevice() {
+  @autoreleasepool {
+    NSArray* devices = [MTLCopyAllDevices() autorelease];
+    if (devices.count == 0) {
+      throw std::runtime_error("Metal is not supported");
+    }
+    static id<MTLDevice> MTL_DEVICE = devices[0];
+    return MTL_DEVICE;
+  }
+}
diff --git a/torchao/experimental/kernels/mps/src/common.h b/torchao/experimental/kernels/mps/src/common.h
new file mode 100644
index 0000000000..0710d37b3a
--- /dev/null
+++ b/torchao/experimental/kernels/mps/src/common.h
@@ -0,0 +1,51 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#ifdef USE_ATEN
+#include <ATen/native/mps/OperationUtils.h>
+using namespace at::native::mps;
+#elif defined(USE_EXECUTORCH)
+#include <executorch/backends/apple/mps/runtime/MPSStream.h>
+using namespace executorch::backends::mps::delegate;
+#else
+#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
+#endif
+
+inline void dispatch_block(
+    MPSStream* mpsStream,
+    void (^block)()) {
+#if defined(USE_ATEN)
+  dispatch_sync_with_rethrow(mpsStream->queue(), block);
+#elif defined(USE_EXECUTORCH)
+  dispatch_sync(mpsStream->queue(), block);
+#else
+  (void)mpsStream;
+  block();
+#endif
+}
+
+inline void optionally_wait_for_command_completion(MPSStream* mpsStream) {
+#if defined(USE_ATEN)
+#elif defined(USE_EXECUTORCH)
+  ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok);
+#else
+  id<MTLComputeCommandEncoder> encoder = mpsStream->commandEncoder();
+  id<MTLCommandBuffer> cmdBuffer = mpsStream->commandBuffer();
+  [encoder endEncoding];
+  [cmdBuffer commit];
+  [cmdBuffer waitUntilCompleted];
+#endif
+}
+
+inline id<MTLDevice> get_metal_device() {
+#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
+  return MPSDevice::getInstance()->device();
+#else
+  return getMetalDevice();
+#endif
+}
diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h
index d37001350a..ae3951e217 100644
--- a/torchao/experimental/kernels/mps/src/lowbit.h
+++ b/torchao/experimental/kernels/mps/src/lowbit.h
@@ -9,24 +9,11 @@
 
 #include <Metal/Metal.h>
 #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
+#include <torchao/experimental/kernels/mps/src/common.h>
 #include <torchao/experimental/kernels/mps/src/dispatch.h>
-#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h>
+#include <torchao/experimental/kernels/mps/src/metal_shader_lib.h> // metal_lowbit_quantized_lib
 #include <torchao/experimental/kernels/mps/src/packing.h>
-#include
-#include
-#include
-
-#ifdef USE_ATEN
-#include <ATen/native/mps/OperationUtils.h>
-using namespace at::native::mps;
-inline void finalize_block(MPSStream* mpsStream) {}
-void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) =
-    dispatch_sync_with_rethrow;
-#else
-#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
-#endif
-
 namespace torchao::kernels::mps::lowbit {
 namespace {
@@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
       0};
 
   MPSStream* mpsStream = getCurrentMPSStream();
-  dispatch_block(mpsStream->queue(), ^() {
+  dispatch_block(mpsStream, ^() {
     @autoreleasepool {
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      id<MTLComputePipelineState> cpl =
@@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl(
                  length:sizeof(uint32_t) * sizes.size()
                 atIndex:5];
       dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
-      finalize_block(mpsStream);
+      optionally_wait_for_command_completion(mpsStream);
     }
   });
 }
diff --git a/torchao/experimental/kernels/mps/test/Makefile b/torchao/experimental/kernels/mps/test/Makefile
index e8213818c5..3c0da54f7c 100644
--- a/torchao/experimental/kernels/mps/test/Makefile
+++ b/torchao/experimental/kernels/mps/test/Makefile
@@ -1,7 +1,7 @@
 all: test_lowbit
 
-test_lowbit: test_lowbit.mm
-	clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation
+test_lowbit: test_lowbit.mm ../src/OperationUtils.mm
+	clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation
 
 run: test_lowbit
 	./test_lowbit
diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm
index 2d86223034..7fb20d254a 100644
--- a/torchao/experimental/kernels/mps/test/test_lowbit.mm
+++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm
@@ -31,7 +31,7 @@
   id<MTLBuffer> rc = [device newBufferWithLength:length
                                          options:MTLResourceStorageModeShared];
   if (rc == nil) {
-    throw_exception(
+    throw std::runtime_error(
         "Can't allocate " + std::to_string(length) + " bytes on GPU");
   }
   return rc;
@@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu(
       : M(m), K(k), N(n), qGroupSize(group_size) {}
 
   void init() {
-    allocBuffers(MTL_DEVICE);
+    allocBuffers(getMetalDevice());
 
     T* a_ptr = reinterpret_cast<T*>([buf_A contents]);
     uint8_t* w_ptr = reinterpret_cast<uint8_t*>([buf_W contents]);
diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt
index 044433ef95..1d41f75854 100644
--- a/torchao/experimental/ops/mps/CMakeLists.txt
+++ b/torchao/experimental/ops/mps/CMakeLists.txt
@@ -26,10 +26,13 @@ endif()
 
 find_package(Torch REQUIRED)
 
 # Generate metal_shader_lib.h by running gen_metal_shader_lib.py
+set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
+set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
 set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
 add_custom_command(
   OUTPUT ${GENERATED_METAL_SHADER_LIB}
-  COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
+  COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
+  DEPENDS ${METAL_SHADERS_DIR}/*.metal ${GEN_SCRIPT}
   COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
 )
 add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
@@ -41,7 +44,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
 include_directories(${TORCHAO_INCLUDE_DIRS})
 include_directories(${CMAKE_INSTALL_PREFIX}/include)
 
-add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
+add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
 add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)
 target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
@@ -53,8 +56,25 @@
 find_library(METAL_LIB Metal)
 find_library(FOUNDATION_LIB Foundation)
 target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
 
-install(
-  TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
-  EXPORT _targets
-  DESTINATION lib
+add_library(torchao_ops_mps_aten SHARED)
+target_link_libraries(torchao_ops_mps_aten PRIVATE
+  torchao_ops_mps_linear_fp_act_xbit_weight_aten
 )
+install(TARGETS torchao_ops_mps_aten DESTINATION lib)
+
+if(TORCHAO_BUILD_EXECUTORCH_OPS)
+  include_directories(${CMAKE_INSTALL_PREFIX}/../..)
+  include_directories(${CMAKE_INSTALL_PREFIX}/schema/include)
+  include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include)
+  add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm)
+  add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib)
+  target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
+  target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate)
+  target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
+
+  add_library(torchao_ops_mps_executorch STATIC)
+  target_link_libraries(torchao_ops_mps_executorch PRIVATE
+    torchao_ops_mps_linear_fp_act_xbit_weight_executorch
+  )
+  install(TARGETS torchao_ops_mps_executorch DESTINATION lib)
+endif()
diff --git a/torchao/experimental/ops/mps/aten/register.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm
similarity index 78%
rename from torchao/experimental/ops/mps/aten/register.mm
rename to torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm
index 92a3ba89f0..e11e55c5a0 100644
--- a/torchao/experimental/ops/mps/aten/register.mm
+++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm
@@ -70,12 +70,13 @@ void check_linear_mps_args(
 }
 
 template <int nbit>
-Tensor linear_mps_kernel(
+Tensor linear_mps_kernel_out(
     const Tensor& A,
     const Tensor& B,
     int64_t group_size,
     const Tensor& S,
-    const Tensor& Z) {
+    const Tensor& Z,
+    Tensor& C) {
   TORCH_CHECK(
       A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
   TORCH_CHECK(
@@ -84,6 +85,8 @@ Tensor linear_mps_kernel(
       S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
   TORCH_CHECK(
       Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
+  TORCH_CHECK(
+      C.is_mps(), __func__, ": C is on ", C.device(), " but expected on mps");
 
   check_linear_mps_args<nbit>(A, B, group_size, S, Z);
 
   auto M = A.size(0);
   auto N = B.size(0);
   auto K = A.size(1);
 
-  auto C = at::empty({M, N}, A.options());
-
   LowBitQuantWeights<nbit>::linear(
       getMTLBufferStorage(A),
       getMTLBufferStorage(B),
@@ -108,6 +109,19 @@ Tensor linear_mps_kernel(
   return C;
 }
 
+template <int nbit>
+Tensor linear_mps_kernel(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto C = at::empty({M, N}, A.options());
+  return linear_mps_kernel_out<nbit>(A, B, group_size, S, Z, C);
+}
+
 template <int nbit>
 Tensor linear_mps_kernel_meta(
     const Tensor& A,
@@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
       "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
       "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
+  m.def(
+      "_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
   m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>);
   m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>);
   m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
+  m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>);
+  m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>);
+  m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>);
+  m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>);
+  m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>);
+  m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>);
+  m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
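Note on the `.out` overloads registered above: they let the caller supply a preallocated result tensor instead of having `at::empty` allocate one per call, which is the calling convention the ExecuTorch kernel below requires. A minimal sketch of exercising both variants from Python, assuming the ATen library has been built into `cmake-out/lib` as in the tests; the row count of `S` and `Z` is an assumption, the other shapes follow the constraints in `check_linear_mps_args`:

```python
import torch

# Load the renamed umbrella library produced by the CMake changes above.
torch.ops.load_library("cmake-out/lib/libtorchao_ops_mps_aten.dylib")

M, K, N, group_size, nbit = 4, 64, 8, 32, 4
A = torch.rand(M, K, device="mps")                       # fp32 activations, 2D, K % 8 == 0
B = torch.randint(0, 256, (N, K // 8 * nbit),            # packed weights: B.size(1) == K/8*nbit
                  dtype=torch.uint8, device="mps")
S = torch.rand(K // group_size, N, device="mps")         # per-group scales, shape [:, N]
Z = torch.zeros(K // group_size, N, device="mps")        # per-group zeros, shape [:, N]

# Functional variant: allocates and returns C.
C = torch.ops.torchao._linear_fp_act_4bit_weight(A, B, group_size, S, Z)

# Out variant: writes into a caller-provided buffer.
out = torch.empty(M, N, device="mps")
torch.ops.torchao._linear_fp_act_4bit_weight.out(A, B, group_size, S, Z, out=out)
torch.testing.assert_close(out, C)
```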
diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm
new file mode 100644
index 0000000000..2892a67245
--- /dev/null
+++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm
@@ -0,0 +1,138 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include
+#include
+#include
+#include
+
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::backends::mps::delegate::getMTLBufferStorage;
+using ::executorch::runtime::KernelRuntimeContext;
+using ::executorch::runtime::tensor_is_rank;
+
+namespace {
+
+std::string scalar_type_to_string(const ScalarType& scalar_type) {
+  switch (scalar_type) {
+    case ScalarType::Float:
+      return "float";
+    case ScalarType::Half:
+      return "half";
+    case ScalarType::BFloat16:
+      return "bfloat";
+    default:
+      ET_CHECK_MSG(
+          false, "Unsupported type by lowbit quantized linear");
+      return "undefined";
+  }
+}
+
+template <int nbit>
+bool check_linear_mps_args(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z) {
+  auto N = B.size(0);
+  auto K = A.size(1);
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      A.scalar_type() == ScalarType::BFloat16 ||
+          A.scalar_type() == ScalarType::Half ||
+          A.scalar_type() == ScalarType::Float,
+      "Expect A to be either 32-bit or 16-bit float tensor.");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      tensor_is_rank(A, 2), "Expect A to be 2D tensor.");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      B.scalar_type() == ScalarType::Byte, "Expect B to be uint8 tensor.");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      B.size(1) == (K / 8) * nbit, "Expect B.size(1) == (K / 8) * nbit");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(K % 8 == 0, "Expect K to be multiple of 8");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      group_size == 32 || group_size == 64 || group_size == 128 ||
+          group_size == 256,
+      "Expect group_size to be 32, 64, 128 or 256");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      S.dim() == 2 && S.size(1) == N,
+      "Expect S to be 2d tensor with shape [:, N]");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      Z.dim() == 2 && Z.size(1) == N,
+      "Expect Z to be 2d tensor with shape [:, N]");
+
+  return true;
+}
+
+template <int nbit>
+Tensor& linear_mps_kernel_et_ctx_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx,
+      check_linear_mps_args<nbit>(A, B, group_size, S, Z),
+      InvalidArgument,
+      out);
+
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+
+  torchao::kernels::mps::lowbit::LowBitQuantWeights<nbit>::linear(
+      getMTLBufferStorage(A),
+      getMTLBufferStorage(B),
+      group_size,
+      getMTLBufferStorage(S),
+      getMTLBufferStorage(Z),
+      getMTLBufferStorage(out),
+      M,
+      K,
+      N,
+      scalar_type_to_string(A.scalar_type()));
+
+  return out;
+}
+
+} // namespace
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_1bit_weight.out", linear_mps_kernel_et_ctx_out<1>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_2bit_weight.out", linear_mps_kernel_et_ctx_out<2>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_3bit_weight.out", linear_mps_kernel_et_ctx_out<3>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_4bit_weight.out", linear_mps_kernel_et_ctx_out<4>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_5bit_weight.out", linear_mps_kernel_et_ctx_out<5>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_6bit_weight.out", linear_mps_kernel_et_ctx_out<6>);
+}
+
+namespace {
+EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_7bit_weight.out", linear_mps_kernel_et_ctx_out<7>);
+}
diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py
index f4c460a368..acff5624c8 100644
--- a/torchao/experimental/ops/mps/test/test_lowbit.py
+++ b/torchao/experimental/ops/mps/test/test_lowbit.py
@@ -11,7 +11,7 @@
 
 from parameterized import parameterized
 
-libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
+libname = "libtorchao_ops_mps_aten.dylib"
 libpath = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
 )
diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py
index 00c08738c2..5b3331c6a8 100644
--- a/torchao/experimental/ops/mps/test/test_quantizer.py
+++ b/torchao/experimental/ops/mps/test/test_quantizer.py
@@ -17,7 +17,7 @@
 from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
 from torchao.experimental.quant_api import _quantize
 
-libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
+libname = "libtorchao_ops_mps_aten.dylib"
 libpath = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
 )
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index be72a59aab..0904d1d174 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -469,21 +469,19 @@ def quantize(self, model: nn.Module) -> nn.Module:
     return model
 
 
-from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout
-from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
-    MappingType,
-    to_affine_quantized_intx,
-    ZeroPointDomain,
-)
-
-
 def int8_dynamic_activation_intx_weight(
     group_size: int = 128,
     nbit: int = 4,
     has_weight_zeros: bool = False,
     target: str = "native",
 ):
+    from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout
+    from torchao.quantization.quant_api import (
+        _get_linear_subclass_inserter,
+        MappingType,
+        to_affine_quantized_intx,
+        ZeroPointDomain,
+    )
 
     def apply(weight):
         assert weight.shape[-1] % group_size == 0
@@ -541,10 +539,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size):
         )
         weight_scales = torch.transpose_copy(weight_scales, 1, 0)
         weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
-        self.weight_scales = weight_scales
-        self.weight_zeros = -weight_zeros * weight_scales
-
-        self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")
+        weight_zeros = -weight_zeros * weight_scales
+        self.weight_scales = nn.Parameter(weight_scales, requires_grad=False)
+        self.weight_zeros = nn.Parameter(weight_zeros, requires_grad=False)
+        packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")
+        self.packed_weights = nn.Parameter(packed_weights, requires_grad=False)
 
     def forward(self, x):
         assert x.dim() >= 2
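A note on the `quant_api.py` change above: wrapping the precomputed tensors in `nn.Parameter(..., requires_grad=False)` registers them with the module, so they show up in `state_dict()` and follow `module.to(device)` moves, unlike plain attributes. A minimal self-contained illustration of that difference (the `Demo` class is made up):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        plain = torch.arange(4.0)
        # Plain attribute: invisible to state_dict() and ignored by .to(...).
        self.as_attr = plain
        # Parameter with requires_grad=False: tracked like any module tensor,
        # the same treatment packed_weights/weight_scales/weight_zeros now get.
        self.as_param = nn.Parameter(plain.clone(), requires_grad=False)

m = Demo()
print(list(m.state_dict().keys()))  # ['as_param'] -- only the Parameter is tracked
```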