diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst
index 8023d0bacc..33c652390d 100644
--- a/docs/source/api_ref_sparsity.rst
+++ b/docs/source/api_ref_sparsity.rst
@@ -12,7 +12,7 @@ torchao.sparsity
     WandaSparsifier
     PerChannelNormObserver
-    apply_sparse_semi_structured
     apply_fake_sparsity
-
-
+    sparsify_
+    semi_sparse_weight
+    int8_dynamic_activation_int8_semi_sparse_weight
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
index 2daefb7773..411343f0da 100644
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
+++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
@@ -29,26 +29,35 @@ namespace torchao {
 #if defined(BUILD_S8S4_LINEAR_CUTLASS)
 
 template<
-    typename ElementA,
-    typename ElementAScale,
-    typename ElementB,
-    typename ElementBScale,
-    typename ElementC,
-    typename ElementAccumulator,
-    typename ElementEpilogue,
-    typename ElementOutput,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
     int NumStages,
-    bool use_tensor_c>
-void s8s4_linear_kernel_cutlass(
+    typename ElementA,
+    typename ElementB,
+    typename ElementAccumulator,
+    typename Operator,
+    typename ElementAScale,
+    typename ElementBScale,
+    typename ElementC,
+    typename UseTensorC,
+    typename ElementOutput>
+void s8s4_linear_kernel_cutlass_sm8x(
     const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
     const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
     const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  using SmArch = cutlass::arch::Sm80;
+
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
-  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
+
+  using ElementEpilogue = float;
+
+  using ThreadblockSwizzle =
+      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+
+  constexpr auto NumEVTEpilogueStages = 1;
 
   const int m = tensor_a.size(0);
   const int n = tensor_b.size(0);
@@ -56,13 +65,13 @@ void s8s4_linear_kernel_cutlass(
   constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
   constexpr int AlignmentAScale =
-      128 / cutlass::sizeof_bits<ElementAScale>::value;
+      128 / cutlass::sizeof_bits<ElementAScale>::value;
   constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
   constexpr int AlignmentBScale =
-      128 / cutlass::sizeof_bits<ElementBScale>::value;
+      128 / cutlass::sizeof_bits<ElementBScale>::value;
   constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
   constexpr int AlignmentOutput =
-      128 / cutlass::sizeof_bits<ElementOutput>::value;
+      128 / cutlass::sizeof_bits<ElementOutput>::value;
 
   // Check for current CUTLASS limitations w.r.t. alignments.
   TORCH_CHECK(k % AlignmentA == 0,
@@ -75,12 +84,6 @@ void s8s4_linear_kernel_cutlass(
               __func__, " : Number of columns of tensor C must be divisible ",
               "by ", AlignmentC);
 
-  using SmArch = cutlass::arch::Sm80;
-  using ThreadblockSwizzle =
-      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
-
-  constexpr auto NumEVTEpilogueStages = 1;
-
   using TensorAScaleTileThreadMap =
       cutlass::epilogue::threadblock::OutputTileThreadLayout<
           ThreadblockShape,
@@ -132,9 +135,9 @@ void s8s4_linear_kernel_cutlass(
       cutlass::epilogue::threadblock::VisitorRowBroadcast<
           TensorCTileThreadMap,
           ElementC,
-          cute::Stride>;
+          cute::Stride>;
   using TensorC =
-      std::conditional_t;
+      std::conditional_t;
   using TensorCArguments = typename TensorC::Arguments;
 
   using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
@@ -178,7 +181,7 @@ void s8s4_linear_kernel_cutlass(
       typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
           ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
           ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
-          ElementC, LayoutC, AlignmentC,
+          ElementOutput, LayoutOutput, AlignmentOutput,
           ElementAccumulator,
           ElementEpilogue,
           cutlass::arch::OpClassTensorOp,
@@ -189,7 +192,7 @@ void s8s4_linear_kernel_cutlass(
           EVTOutput,
           ThreadblockSwizzle,
           NumStages,
-          cutlass::arch::OpMultiplyAddMixedInputUpcast,
+          Operator,
           NumEVTEpilogueStages
   >::GemmKernel;
@@ -210,7 +213,7 @@ void s8s4_linear_kernel_cutlass(
   };
   TensorCArguments tensor_c_arguments{
     [&]() -> TensorCArguments {
-      if constexpr (use_tensor_c) {
+      if constexpr (UseTensorC::value) {
        return {(ElementC*)tensor_c.data_ptr(),
                ElementC(0),
                {cute::_0{}, cute::_1{}, problem_size.n()}};
@@ -282,127 +285,193 @@ void s8s4_linear_kernel_cutlass(
   // Perform mixed datatypes GEMM operation.
   status = gemm_op.run(at::cuda::getCurrentCUDAStream());
   CUTLASS_STATUS_CHECK(status);
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template<
-    typename ElementA,
-    typename ElementAScale,
-    typename ElementB,
-    typename ElementBScale,
-    typename ElementC,
-    typename ElementAccumulator,
-    typename ElementEpilogue,
-    typename ElementOutput,
-    bool use_tensor_c>
-void
-s8s4_linear_cutlass_dispatch_shapes(
+template<typename ElementA, typename ElementB, typename... Types>
+static void select_config(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  const auto is_sm8x = dprops->major == 8;
+
+  if (is_sm8x) {
+    if constexpr (std::is_same<ElementA, int8_t>::value &&
+                  std::is_same<ElementB, cutlass::int4b_t>::value) {
+      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+
+      // A minimal heuristic to improve performance for small number
+      // of inputs cases.
+      if (tensor_a.size(0) <= 16) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
+        constexpr auto NumStages = 6;
+        s8s4_linear_kernel_cutlass_sm8x<
+            ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA,
+            ElementB, Types...>(
+            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+            tensor_d);
+      } else if (tensor_a.size(0) <= 32) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
+        constexpr auto NumStages = 5;
+        s8s4_linear_kernel_cutlass_sm8x<
+            ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA,
+            ElementB, Types...>(
+            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+            tensor_d);
+      } else {
+        using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
+        constexpr auto NumStages = 4;
+        s8s4_linear_kernel_cutlass_sm8x<
+            ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA,
+            ElementB, Types...>(
+            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+            tensor_d);
+      }
+      return;
+    }
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported on SM", dprops->major, ".",
+              dprops->minor, " for given operands");
+}
+
+template<typename... Types>
+static void
+dispatch_on_tensor_a_and_tensor_b(
     const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
     const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
     const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-
-  // A minimal heuristic to improve performance for small number of
-  // inputs cases.
-  if (tensor_a.size(0) <= 16) {
-    using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
-    using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
-    constexpr auto NumStages = 6;
-    s8s4_linear_kernel_cutlass<
-        ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-        ElementAccumulator, ElementEpilogue, ElementOutput,
-        ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-        tensor_d);
-  } else if (tensor_a.size(0) <= 32) {
-    using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
-    using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
-    constexpr auto NumStages = 5;
-    s8s4_linear_kernel_cutlass<
-        ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-        ElementAccumulator, ElementEpilogue, ElementOutput,
-        ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-        tensor_d);
-  } else {
-    using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
-    using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
-    constexpr auto NumStages = 4;
-    s8s4_linear_kernel_cutlass<
-        ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-        ElementAccumulator, ElementEpilogue, ElementOutput,
-        ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>(
+  if (tensor_a.scalar_type() == at::ScalarType::Char) {
+    if (tensor_b.scalar_type() == at::ScalarType::Char) {
+      if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
+        using ElementA = int8_t;
+        using ElementB = cutlass::int4b_t;
+        using ElementAccumulator = int32_t;
+        using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;
+        select_config<
+            ElementA, ElementB, ElementAccumulator, Operator, Types...>(
         tensor_a, tensor_a_scale, tensor_b, tensor_b_scale,
         tensor_c, tensor_d);
+      }
+      return;
+    }
   }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for combination of data ",
+              "types ", tensor_a.scalar_type(), " for first operand and ",
+              tensor_b.scalar_type(), " for second operand");
 }
-#endif
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-//   result = (input * input_scale) @ (weight * weight_scale).T + bias
-// Notes: The "input_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "input" tensor. The "weight_scale"
-// tensor is expected to be a vector, of size equal to number of rows
-// of "weight" tensor. The "bias" tensor is expected to be a vector,
-// of size equal to number of rows of "weight" tensor.
-at::Tensor
-s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
-                    const at::Tensor& weight, const at::Tensor& weight_scale,
-                    const at::Tensor& bias) {
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-  // For now, only CC 8.x devices are supported.
-  const auto dprops = at::cuda::getCurrentDeviceProperties();
-  const auto is_sm8x = dprops->major == 8;
-  TORCH_CHECK(is_sm8x,
-              __func__, " : Supported only on GPUs with compute capability "
-              "8.x");
-
-  // Validate datatypes of arguments.
-  TORCH_CHECK(input.dtype() == at::kChar,
-              __func__, " : The input datatype ", input.dtype(),
-              " not supported");
-  TORCH_CHECK(input_scale.dtype() == at::kHalf ||
-              input_scale.dtype() == at::kBFloat16,
-              __func__, " : The input scale datatype ", input_scale.dtype(),
-              " not supported");
-  TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ",
-              weight.dtype(), " not supported");
-  TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(),
-              __func__, " : Expected weight scale datatype ",
-              input_scale.dtype(), ", got ", weight_scale.dtype());
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.dtype() == input_scale.dtype(),
-                __func__, " : Expected bias datatype ", input_scale.dtype(),
-                ", got ", bias.dtype());
+template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
+static void
+dispatch_on_tensor_c(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  if (tensor_c.numel() == 0) {
+    using ElementC = ElementOutput;
+    using UseTensorC = std::false_type;
+    dispatch_on_tensor_a_and_tensor_b<
+        ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+        tensor_d);
+    return;
+  }
+
+  using UseTensorC = std::true_type;
+  if (tensor_c.scalar_type() == at::ScalarType::Half) {
+    using ElementC = cutlass::half_t;
+    dispatch_on_tensor_a_and_tensor_b<
+        ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+        tensor_d);
+    return;
+  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementC = cutlass::bfloat16_t;
+    dispatch_on_tensor_a_and_tensor_b<
+        ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+        tensor_d);
+    return;
   }
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for datatype ",
+              tensor_c.scalar_type(), " for addend");
+}
+
+static void
+dispatch_on_tensor_a_scale_and_tensor_b_scale(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
+              __func__, " : Operator not supported for output datatype ",
+              tensor_d.scalar_type(), " as it's different from the first ",
+              " operand scale datatype ", tensor_a_scale.scalar_type());
+
+  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
+      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
+    using ElementAScale = cutlass::half_t;
+    using ElementBScale = cutlass::half_t;
+    using ElementOutput = cutlass::half_t;
+    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
+             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementAScale = cutlass::bfloat16_t;
+    using ElementBScale = cutlass::bfloat16_t;
+    using ElementOutput = cutlass::bfloat16_t;
+    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for combination of data ",
+              "types ", tensor_a_scale.scalar_type(),
+              " for first operand scale and ", tensor_b_scale.scalar_type(),
+              " for second operand scale");
+}
+
+void
+check_inputs(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
   // Validate layouts of arguments.
-  TORCH_CHECK(input.dim() >= 2,
-              __func__, " : Expected input argument to be 2D or "
-              "higher-dimensional tensor, got ", input.dim(), " dims");
-  TORCH_CHECK(input.layout() == at::Layout::Strided,
-              __func__, " : Expected input argument to be strided, got layout ",
-              input.layout());
-  TORCH_CHECK(input_scale.dim() == input.dim() - 1,
-              __func__, " : Expected input scale argument to be ",
-              input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims");
-  TORCH_CHECK(input_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected input scale argument to be strided, got "
-              "layout ", input_scale.layout());
-  TORCH_CHECK(weight.dim() == 2,
-              __func__, " : Expected weight argument to be 2D tensor, got ",
-              weight.dim(), " dims");
-  TORCH_CHECK(weight.layout() == at::Layout::Strided,
-              __func__,
-              " : Expected weight argument to be strided, got layout ",
-              weight.layout());
-  TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2,
-              __func__, " : Expected weight scale argument to be 1D or 2D ",
-              "tensor, got ", weight_scale.dim(), " dims");
-  TORCH_CHECK(weight_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected weight scale argument to be strided, got "
-              "layout ", weight_scale.layout());
+  TORCH_CHECK(xq.dim() >= 2,
+              __func__, " : Expected xq argument to be 2D or "
+              "higher-dimensional tensor, got ", xq.dim(), " dims");
+  TORCH_CHECK(xq.layout() == at::Layout::Strided,
+              __func__, " : Expected xq argument to be strided, got layout ",
+              xq.layout());
+  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
+              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
+              "D tensor, got ", x_scale.dim(), " dims");
+  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected xq scale argument to be strided, got "
+              "layout ", x_scale.layout());
+  TORCH_CHECK(wq.dim() == 2,
+              __func__, " : Expected wq argument to be 2D tensor, got ",
+              wq.dim(), " dims");
+  TORCH_CHECK(wq.layout() == at::Layout::Strided,
+              __func__, " : Expected wq argument to be strided, got layout ",
+              wq.layout());
+  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
+              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
+              "got ", w_scale.dim(), " dims");
+  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected wq scale argument to be strided, got "
+              "layout ", w_scale.layout());
   if (bias.numel() > 0) {
     TORCH_CHECK(bias.dim() == 1,
                 __func__, " : Expected bias argument to be 1D tensor, got ",
@@ -412,116 +481,92 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
                 "layout ", bias.layout());
   }
 
-  // Squash the input tensor to 2D tensor.
-  const auto input_sizes = input.sizes().vec();
-  const auto input_2d = input.reshape({-1, input_sizes.back()});
-  const auto input_scale_sizes = input_scale.sizes().vec();
-  const auto input_scale_1d = input_scale.reshape({-1});
-  const auto weight_scale_1d = weight_scale.reshape({-1});
 
   // Validate sizes of arguments.
-  TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1),
-              __func__, " : Expected input argument to have ",
-              2 * weight.size(1), " columns, but got ", input_2d.size(1));
-  for (auto i = 0; i < input_scale_sizes.size(); ++i)
-    TORCH_CHECK(input_scale_sizes[i] == input_sizes[i],
-                __func__, " : Expected input scale argument size at position ",
-                i, " to be ", input_sizes[i], ", but got ",
-                input_scale_sizes[i]);
-  TORCH_CHECK(weight_scale_1d.numel() == weight.size(0),
-              __func__, " : Expected weight scale argument to have ",
-              weight.size(0), " elements, got ", weight_scale_1d.numel(),
-              " elements");
+  const auto xq_sizes = xq.sizes().vec();
+  TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1),
+              __func__, " : Expected xq argument to have ", 2 * wq.size(1),
+              " columns, but got ", xq_sizes.back());
+  const auto x_scale_sizes = x_scale.sizes().vec();
+  for (auto i = 0; i < x_scale_sizes.size(); ++i)
+    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
+                __func__, " : Expected xq scale argument size at position ",
+                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
+  TORCH_CHECK(w_scale.numel() == wq.size(0),
+              __func__, " : Expected wq scale argument to have ", wq.size(0),
+              " elements, got ", w_scale.numel(), " elements");
   if (bias.numel() > 0) {
-    TORCH_CHECK(bias.numel() == weight.size(0),
-                __func__, " : Expected bias argument to have ", weight.size(0),
+    TORCH_CHECK(bias.numel() == wq.size(0),
+                __func__, " : Expected bias argument to have ", wq.size(0),
                 " elements, got ", bias.numel(), " elements");
   }
 
   // Validate strides of arguments.
-  const auto input_2d_strides = input_2d.strides();
-  TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1,
-              __func__, " : Expected input argument in row-major layout");
-  const auto input_scale_1d_strides = input_scale_1d.strides();
-  TORCH_CHECK(input_scale_1d_strides[0] == 1,
-              __func__, " : Expected input scale argument to be contiguous");
-  const auto weight_strides = weight.strides();
-  TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1,
-              __func__, " : Expected weight argument in row-major layout");
-  const auto weight_scale_1d_strides = weight_scale_1d.strides();
-  TORCH_CHECK(weight_scale_1d_strides[0] == 1,
-              __func__, " : Expected weight scale argument to be contiguous");
+  const auto xq_strides = xq.strides();
+  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
+              __func__, " : Expected xq argument in row-major layout");
+  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
+  for (int i = xq_strides.size() - 3; i >= 0; --i) {
+    xq_stride_expected *= xq_sizes[i + 1];
+    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
+                __func__, " : Expected xq argument in row-major layout");
+  }
+  TORCH_CHECK(x_scale.is_contiguous(),
+              __func__, " : Expected xq scale argument to be contiguous");
+  const auto wq_strides = wq.strides();
+  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
+              __func__, " : Expected wq argument in row-major layout");
+  TORCH_CHECK(w_scale.is_contiguous(),
+              __func__, " : Expected wq scale argument to be contiguous");
   if (bias.numel() > 0) {
     const auto bias_strides = bias.strides();
     TORCH_CHECK(bias_strides[0] == 1,
                 __func__, " : Expected bias argument to be contiguous");
   }
+}
+#endif
+
+// Perform linear operation, using corresponding CUTLASS mixed
+// data-types GEMM kernel, to given arguments:
+//   result = (xq * x_scale) @ (wq * w_scale).T + bias
+// Notes: The "x_scale" tensor is expected to be a vector, of size
+// equal to number of rows of "xq" tensor. The "w_scale" tensor is
+// expected to be a vector, of size equal to number of rows of "wq"
+// tensor. The "bias" tensor is expected to be a vector, of size equal
+// to number of rows of "wq" tensor.
+at::Tensor
+s8s4_linear_cutlass(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+#if defined(BUILD_S8S4_LINEAR_CUTLASS)
+  // Check inputs.
+  check_inputs(xq, x_scale, wq, w_scale, bias);
+
+  // Squash the input tensors as appropriate.
+  const auto xq_sizes = xq.sizes().vec();
+  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
+  const auto x_scale_sizes = x_scale.sizes().vec();
+  const auto x_scale_1d = x_scale.reshape({-1});
+  const auto w_scale_1d = w_scale.reshape({-1});
 
   // Introduce alias names for arguments, according to the CUTLASS
   // naming conventions.
-  const auto& tensor_a = input_2d;
-  const auto& tensor_a_scale = input_scale_1d;
-  const auto& tensor_b = weight;
-  const auto& tensor_b_scale = weight_scale_1d;
+  const auto& tensor_a = xq_2d;
+  const auto& tensor_a_scale = x_scale_1d;
+  const auto& tensor_b = wq;
+  const auto& tensor_b_scale = w_scale_1d;
   const auto& tensor_c = bias;
 
   // Create output tensor.
   at::Tensor tensor_d =
       tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
 
-  using ElementA = int8_t;
-  using ElementB = cutlass::int4b_t;
-  using ElementAccumulator = int32_t;
-
-  AT_DISPATCH_SWITCH(
-      input_scale.scalar_type(),
-      "s8s4_linear_cutlass",
-      AT_DISPATCH_CASE(
-          at::ScalarType::Half,
-          [&]() {
-            using ElementAScale = cutlass::half_t;
-            using ElementBScale = cutlass::half_t;
-            using ElementC = cutlass::half_t;
-            using ElementEpilogue = float;
-            using ElementOutput = cutlass::half_t;
-            if (bias.numel() > 0) {
-              s8s4_linear_cutlass_dispatch_shapes<
-                  ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-                  ElementAccumulator, ElementEpilogue, ElementOutput, true>(
-                  tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-                  tensor_d);
-            } else {
-              s8s4_linear_cutlass_dispatch_shapes<
-                  ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-                  ElementAccumulator, ElementEpilogue, ElementOutput, false>(
-                  tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-                  tensor_d);
-            }
-          })
-      AT_DISPATCH_CASE(
-          at::ScalarType::BFloat16,
-          [&]() {
-            using ElementAScale = cutlass::bfloat16_t;
-            using ElementBScale = cutlass::bfloat16_t;
-            using ElementC = cutlass::bfloat16_t;
-            using ElementEpilogue = float;
-            using ElementOutput = cutlass::bfloat16_t;
-            if (bias.numel() > 0) {
-              s8s4_linear_cutlass_dispatch_shapes<
-                  ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-                  ElementAccumulator, ElementEpilogue, ElementOutput, true>(
-                  tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-                  tensor_d);
-            } else {
-              s8s4_linear_cutlass_dispatch_shapes<
-                  ElementA, ElementAScale, ElementB, ElementBScale, ElementC,
-                  ElementAccumulator, ElementEpilogue, ElementOutput, false>(
-                  tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-                  tensor_d);
-            }
-          }));
-
-  auto tensor_d_sizes = input_sizes;
-  tensor_d_sizes.back() = weight.size(0);
+  // Dispatch to appropriate kernel template.
+  dispatch_on_tensor_a_scale_and_tensor_b_scale(
+      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+
+  // Reshape and return output tensor.
+  auto tensor_d_sizes = xq_sizes;
+  tensor_d_sizes.back() = wq.size(0);
   return tensor_d.reshape(tensor_d_sizes);
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index 3dd7971525..eb31cba619 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -43,7 +43,7 @@ def sparsify_(
     apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
 ) -> torch.nn.Module:
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
     This function is essentially the same as quantize, put for sparsity subclasses.
 
     Currently, we support three options for sparsity:
@@ -54,26 +54,26 @@ def sparsify_(
     Args:
         model (torch.nn.Module): input model
         apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
-        the weight of the module
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
 
-    Example::
-        import torch
-        import torch.nn as nn
-        from torchao.sparsity import sparsify_
+    **Example:**
+    ::
+            import torch
+            import torch.nn as nn
+            from torchao.sparsity import sparsify_
 
-        def filter_fn(module: nn.Module, fqn: str) -> bool:
-            return isinstance(module, nn.Linear)
+            def filter_fn(module: nn.Module, fqn: str) -> bool:
+                return isinstance(module, nn.Linear)
 
-        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
+            m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
 
-        # for 2:4 sparsity
-        from torchao.sparse_api import semi_sparse_weight
-        m = sparsify_(m, semi_sparse_weight(), filter_fn)
+            # for 2:4 sparsity
+            from torchao.sparse_api import semi_sparse_weight
+            m = sparsify_(m, semi_sparse_weight(), filter_fn)
 
-        # for int8 dynamic quantization + 2:4 sparsity
-        from torchao.dtypes import SemiSparseLayout
-        m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
+            # for int8 dynamic quantization + 2:4 sparsity
+            from torchao.dtypes import SemiSparseLayout
+            m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
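Note (illustrative, not part of the patch): the comment above s8s4_linear_cutlass defines the operation as result = (xq * x_scale) @ (wq * w_scale).T + bias, with per-row scales for xq and wq. Below is a minimal PyTorch sketch of that reference semantics, assuming wq has already been unpacked to one int4 value per int8 element and using float32 scales for simplicity; the real operator consumes packed int4 weights, requires half or bfloat16 scales, and runs a mixed int8/int4 CUTLASS GEMM on SM 8.x GPUs.

import torch

def s8s4_linear_reference(xq, x_scale, wq, w_scale, bias):
    # result = (xq * x_scale) @ (wq * w_scale).T + bias, scales applied row-wise
    x = xq.to(x_scale.dtype) * x_scale.unsqueeze(-1)
    w = wq.to(w_scale.dtype) * w_scale.unsqueeze(-1)
    out = x @ w.t()
    if bias.numel() > 0:
        out = out + bias
    return out

# Hypothetical shapes for illustration: m=16, k=64, n=32.
xq = torch.randint(-128, 127, (16, 64), dtype=torch.int8)
wq = torch.randint(-8, 8, (32, 64), dtype=torch.int8)  # int4 values held in int8
x_scale = torch.rand(16)   # one scale per row of xq
w_scale = torch.rand(32)   # one scale per row of wq
bias = torch.rand(32)
print(s8s4_linear_reference(xq, x_scale, wq, w_scale, bias).shape)  # torch.Size([16, 32])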