diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 8023d0bacc..33c652390d 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -12,7 +12,7 @@ torchao.sparsity WandaSparsifier PerChannelNormObserver - apply_sparse_semi_structured apply_fake_sparsity - - + sparsify_ + semi_sparse_weight + int8_dynamic_activation_int8_semi_sparse_weight diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 3a6e6ab048..0bde173b6d 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -48,7 +48,7 @@ Our workflow is designed to consist of two parts that answer each question indep The handoff point between these two pieces are sparse weights stored in a dense format, with 0 in the place of missing elements. This is a natural handoff point because sparse matrix multiplication and dense matrix multiplication with this tensor will be numerically equivalent. This lets us present a clear contract to the user for our backend, for a given sparsity pattern: -**\ *If you can get your dense matrix into a [2:4 sparse format], we can speed up matrix multiplication up to [1.7x] with no numerical loss.*\ ** +If you can get your dense matrix into a **2:4 sparse format**, we can speed up matrix multiplication up to **1.7x** with no numerical loss. This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate many users to come up with their own custom frontend masking solution or to use another third party solution, as this is an active area of research. @@ -102,9 +102,9 @@ Context This section provides some context on neural network pruning/sparsity as well as definitions for some common pruning/sparsity terms. In academia / industry, **pruning** and **sparsity** are often used interchangeably to refer to the same thing. This can be confusing, especially since sparsity is an overloaded term that can refer to many other things, such as sparse tensor representations. -Note that this section focuses on **pruning**\ , instead of **sparse training**. The distinction being that in **pruning** we start with a pretrained dense model, while during **sparse training** we train a sparse model from scratch. +Note that this section focuses on **pruning**, instead of **sparse training**. The distinction being that in **pruning** we start with a pretrained dense model, while during **sparse training** we train a sparse model from scratch. -**In order to avoid confusion, we generally try to use sparsity to refer to tensors. Note that a sparse tensor can refer to a dense tensor with many zero values, or a tensor stored using a sparse representation. We describe the flow as *pruning* and the resultant model as a *pruned* model.** +In order to avoid confusion, we generally try to use sparsity to refer to tensors. Note that a sparse tensor can refer to a dense tensor with many zero values, or a tensor stored using a sparse representation. We describe the flow as **pruning** and the resultant model as a **pruned** model. 
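To make the handoff contract above concrete, here is a minimal sketch (not part of this patch; it assumes a compute-capability 8.x GPU and PyTorch's prototype ``torch.sparse.to_sparse_semi_structured`` API) of compressing a dense weight that already follows the 2:4 pattern and checking that the sparse and dense matmuls agree::

    import torch
    from torch.sparse import to_sparse_semi_structured

    # Dense weight whose values already follow the 2:4 pattern:
    # two non-zero values in every group of four elements along a row.
    W = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # 128 x 128
    x = torch.rand(128, 128).half().cuda()

    # Dense reference result.
    y_dense = torch.mm(W, x)

    # Handoff: compress the dense-but-2:4-sparse weight into the
    # semi-structured representation and run the sparse kernel instead.
    W_sparse = to_sparse_semi_structured(W)
    y_sparse = torch.mm(W_sparse, x)

    # The two paths are numerically equivalent.
    assert torch.allclose(y_dense, y_sparse, rtol=1e-3, atol=1e-3)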
Roughly, the flow for achieving a more performant pruned model looks like this: diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 2daefb7773..411343f0da 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -29,26 +29,35 @@ namespace torchao { #if defined(BUILD_S8S4_LINEAR_CUTLASS) template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, typename ThreadblockShape, typename WarpShape, typename InstructionShape, int NumStages, - bool use_tensor_c> -void s8s4_linear_kernel_cutlass( + typename ElementA, + typename ElementB, + typename ElementAccumulator, + typename Operator, + typename ElementAScale, + typename ElementBScale, + typename ElementC, + typename UseTensorC, + typename ElementOutput> +void s8s4_linear_kernel_cutlass_sm8x( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using SmArch = cutlass::arch::Sm80; + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementEpilogue = float; + + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + + constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); const int n = tensor_b.size(0); @@ -56,13 +65,13 @@ void s8s4_linear_kernel_cutlass( constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; // Check for current CUTLASS limitations w.r.t. alignments. 
TORCH_CHECK(k % AlignmentA == 0, @@ -75,12 +84,6 @@ void s8s4_linear_kernel_cutlass( __func__, " : Number of columns of tensor C must be divisible ", "by ", AlignmentC); - using SmArch = cutlass::arch::Sm80; - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; - - constexpr auto NumEVTEpilogueStages = 1; - using TensorAScaleTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, @@ -132,9 +135,9 @@ void s8s4_linear_kernel_cutlass( cutlass::epilogue::threadblock::VisitorRowBroadcast< TensorCTileThreadMap, ElementC, - cute::Stride>; + cute::Stride>; using TensorC = - std::conditional_t; + std::conditional_t; using TensorCArguments = typename TensorC::Arguments; using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< @@ -178,7 +181,7 @@ void s8s4_linear_kernel_cutlass( typename cutlass::gemm::kernel::DefaultGemmWithVisitor< ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementC, LayoutC, AlignmentC, + ElementOutput, LayoutOutput, AlignmentOutput, ElementAccumulator, ElementEpilogue, cutlass::arch::OpClassTensorOp, @@ -189,7 +192,7 @@ void s8s4_linear_kernel_cutlass( EVTOutput, ThreadblockSwizzle, NumStages, - cutlass::arch::OpMultiplyAddMixedInputUpcast, + Operator, NumEVTEpilogueStages >::GemmKernel; @@ -210,7 +213,7 @@ void s8s4_linear_kernel_cutlass( }; TensorCArguments tensor_c_arguments{ [&]() -> TensorCArguments { - if constexpr (use_tensor_c) { + if constexpr (UseTensorC::value) { return {(ElementC*)tensor_c.data_ptr(), ElementC(0), {cute::_0{}, cute::_1{}, problem_size.n()}}; @@ -282,127 +285,193 @@ void s8s4_linear_kernel_cutlass( // Perform mixed datatypes GEMM operation. status = gemm_op.run(at::cuda::getCurrentCUDAStream()); CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, - bool use_tensor_c> -void -s8s4_linear_cutlass_dispatch_shapes( +template +static void select_config( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + + if (is_sm8x) { + if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + // A minimal heuristic to improve performance for small number + // of inputs cases. 
+ if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + return; + } + } + + TORCH_CHECK(false, + __func__, " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_tensor_a_and_tensor_b( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - - // A minimal heuristic to improve performance for small number of - // inputs cases. - if (tensor_a.size(0) <= 16) { - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; - constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; - constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; - constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + if (tensor_a.scalar_type() == at::ScalarType::Char) { + if (tensor_b.scalar_type() == at::ScalarType::Char) { + if (tensor_a.size(1) == 2 * tensor_b.size(1)) { + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; + select_config< + ElementA, ElementB, ElementAccumulator, Operator, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, 
tensor_c, tensor_d); + } + return; + } } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a.scalar_type(), " for first operand and ", + tensor_b.scalar_type(), " for second operand"); } -#endif -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (input * input_scale) @ (weight * weight_scale).T + bias -// Notes: The "input_scale" tensor is expected to be a vector, of size -// equal to number of rows of "input" tensor. The "weight_scale" -// tensor is expected to be a vector, of size equal to number of rows -// of "weight" tensor. The "bias" tensor is expected to be a vector, -// of size equal to number of rows of "weight" tensor. -at::Tensor -s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, - const at::Tensor& weight, const at::Tensor& weight_scale, - const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) - // For now, only CC 8.x devices are supported. - const auto dprops = at::cuda::getCurrentDeviceProperties(); - const auto is_sm8x = dprops->major == 8; - TORCH_CHECK(is_sm8x, - __func__, " : Supported only on GPUs with compute capability " - "8.x"); - - // Validate datatypes of arguments. - TORCH_CHECK(input.dtype() == at::kChar, - __func__, " : The input datatype ", input.dtype(), - " not supported"); - TORCH_CHECK(input_scale.dtype() == at::kHalf || - input_scale.dtype() == at::kBFloat16, - __func__, " : The input scale datatype ", input_scale.dtype(), - " not supported"); - TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ", - weight.dtype(), " not supported"); - TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(), - __func__, " : Expected weight scale datatype ", - input_scale.dtype(), ", got ", weight_scale.dtype()); - if (bias.numel() > 0) { - TORCH_CHECK(bias.dtype() == input_scale.dtype(), - __func__, " : Expected bias datatype ", input_scale.dtype(), - ", got ", bias.dtype()); +template +static void +dispatch_on_tensor_c( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + if (tensor_c.numel() == 0) { + using ElementC = ElementOutput; + using UseTensorC = std::false_type; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + using UseTensorC = std::true_type; + if (tensor_c.scalar_type() == at::ScalarType::Half) { + using ElementC = cutlass::half_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { + using ElementC = cutlass::bfloat16_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; } + TORCH_CHECK(false, + __func__, " : Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); +} + +static void +dispatch_on_tensor_a_scale_and_tensor_b_scale( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + 
TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), + __func__, " : Operator not supported for output datatype ", + tensor_d.scalar_type(), " as it's different from the first ", + " operand scale datatype ", tensor_a_scale.scalar_type()); + + if (tensor_a_scale.scalar_type() == at::ScalarType::Half && + tensor_b_scale.scalar_type() == at::ScalarType::Half) { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementOutput = cutlass::half_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && + tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a_scale.scalar_type(), + " for first operand scale and ", tensor_b_scale.scalar_type(), + " for second operand scale"); +} + +void +check_inputs( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { // Validate layouts of arguments. - TORCH_CHECK(input.dim() >= 2, - __func__, " : Expected input argument to be 2D or " - "higher-dimensional tensor, got ", input.dim(), " dims"); - TORCH_CHECK(input.layout() == at::Layout::Strided, - __func__, " : Expected input argument to be strided, got layout ", - input.layout()); - TORCH_CHECK(input_scale.dim() == input.dim() - 1, - __func__, " : Expected input scale argument to be ", - input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims"); - TORCH_CHECK(input_scale.layout() == at::Layout::Strided, - __func__, " : Expected input scale argument to be strided, got " - "layout ", input_scale.layout()); - TORCH_CHECK(weight.dim() == 2, - __func__, " : Expected weight argument to be 2D tensor, got ", - weight.dim(), " dims"); - TORCH_CHECK(weight.layout() == at::Layout::Strided, - __func__, - " : Expected weight argument to be strided, got layout ", - weight.layout()); - TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2, - __func__, " : Expected weight scale argument to be 1D or 2D ", - "tensor, got ", weight_scale.dim(), " dims"); - TORCH_CHECK(weight_scale.layout() == at::Layout::Strided, - __func__, " : Expected weight scale argument to be strided, got " - "layout ", weight_scale.layout()); + TORCH_CHECK(xq.dim() >= 2, + __func__, " : Expected xq argument to be 2D or " + "higher-dimensional tensor, got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, + __func__, " : Expected xq argument to be strided, got layout ", + xq.layout()); + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, + __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + "D tensor, got ", x_scale.dim(), " dims"); + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, + __func__, " : Expected xq scale argument to be strided, got " + "layout ", x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, + __func__, " : Expected wq argument to be 2D tensor, got ", + wq.dim(), " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, + __func__, " : Expected wq argument to be strided, got layout ", + wq.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, + __func__, " : Expected wq scale 
argument to be 1D or 2D tensor, ", + "got ", w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, + __func__, " : Expected wq scale argument to be strided, got " + "layout ", w_scale.layout()); if (bias.numel() > 0) { TORCH_CHECK(bias.dim() == 1, __func__, " : Expected bias argument to be 1D tensor, got ", @@ -412,116 +481,92 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, "layout ", bias.layout()); } - // Squash the input tensor to 2D tensor. - const auto input_sizes = input.sizes().vec(); - const auto input_2d = input.reshape({-1, input_sizes.back()}); - const auto input_scale_sizes = input_scale.sizes().vec(); - const auto input_scale_1d = input_scale.reshape({-1}); - const auto weight_scale_1d = weight_scale.reshape({-1}); - // Validate sizes of arguments. - TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), - __func__, " : Expected input argument to have ", - 2 * weight.size(1), " columns, but got ", input_2d.size(1)); - for (auto i = 0; i < input_scale_sizes.size(); ++i) - TORCH_CHECK(input_scale_sizes[i] == input_sizes[i], - __func__, " : Expected input scale argument size at position ", - i, " to be ", input_sizes[i], ", but got ", - input_scale_sizes[i]); - TORCH_CHECK(weight_scale_1d.numel() == weight.size(0), - __func__, " : Expected weight scale argument to have ", - weight.size(0), " elements, got ", weight_scale_1d.numel(), - " elements"); + const auto xq_sizes = xq.sizes().vec(); + TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), + __func__, " : Expected xq argument to have ", 2 * wq.size(1), + " columns, but got ", xq_sizes.back()); + const auto x_scale_sizes = x_scale.sizes().vec(); + for (auto i = 0; i < x_scale_sizes.size(); ++i) + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], + __func__, " : Expected xq scale argument size at position ", + i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), + __func__, " : Expected wq scale argument to have ", wq.size(0), + " elements, got ", w_scale.numel(), " elements"); if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == weight.size(0), - __func__, " : Expected bias argument to have ", weight.size(0), + TORCH_CHECK(bias.numel() == wq.size(0), + __func__, " : Expected bias argument to have ", wq.size(0), " elements, got ", bias.numel(), " elements"); } // Validate strides of arguments. 
- const auto input_2d_strides = input_2d.strides(); - TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1, - __func__, " : Expected input argument in row-major layout"); - const auto input_scale_1d_strides = input_scale_1d.strides(); - TORCH_CHECK(input_scale_1d_strides[0] == 1, - __func__, " : Expected input scale argument to be contiguous"); - const auto weight_strides = weight.strides(); - TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1, - __func__, " : Expected weight argument in row-major layout"); - const auto weight_scale_1d_strides = weight_scale_1d.strides(); - TORCH_CHECK(weight_scale_1d_strides[0] == 1, - __func__, " : Expected weight scale argument to be contiguous"); + const auto xq_strides = xq.strides(); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, + __func__, " : Expected xq argument in row-major layout"); + auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; + for (int i = xq_strides.size() - 3; i >= 0; --i) { + xq_stride_expected *= xq_sizes[i + 1]; + TORCH_CHECK(xq_strides[i] == xq_stride_expected, + __func__, " : Expected xq argument in row-major layout"); + } + TORCH_CHECK(x_scale.is_contiguous(), + __func__, " : Expected xq scale argument to be contiguous"); + const auto wq_strides = wq.strides(); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, + __func__, " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), + __func__, " : Expected wq scale argument to be contiguous"); if (bias.numel() > 0) { const auto bias_strides = bias.strides(); TORCH_CHECK(bias_strides[0] == 1, __func__, " : Expected bias argument to be contiguous"); } +} +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (xq * x_scale) @ (wq * w_scale).T + bias +// Notes: The "x_scale" tensor is expected to be a vector, of size +// equal to number of rows of "xq" tensor. The "w_scale" tensor is +// expected to be a vector, of size equal to number of rows of "wq" +// tensor. The "bias" tensor is expected to be a vector, of size equal +// to number of rows of "wq" tensor. +at::Tensor +s8s4_linear_cutlass( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { +#if defined(BUILD_S8S4_LINEAR_CUTLASS) + // Check inputs. + check_inputs(xq, x_scale, wq, w_scale, bias); + + // Squash the input tensors as appropriate. + const auto xq_sizes = xq.sizes().vec(); + const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); + const auto x_scale_sizes = x_scale.sizes().vec(); + const auto x_scale_1d = x_scale.reshape({-1}); + const auto w_scale_1d = w_scale.reshape({-1}); // Introduce alias names for arguments, according to the CUTLASS // naming conventions. - const auto& tensor_a = input_2d; - const auto& tensor_a_scale = input_scale_1d; - const auto& tensor_b = weight; - const auto& tensor_b_scale = weight_scale_1d; + const auto& tensor_a = xq_2d; + const auto& tensor_a_scale = x_scale_1d; + const auto& tensor_b = wq; + const auto& tensor_b_scale = w_scale_1d; const auto& tensor_c = bias; // Create output tensor. 
at::Tensor tensor_d = tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - AT_DISPATCH_SWITCH( - input_scale.scalar_type(), - "s8s4_linear_cutlass", - AT_DISPATCH_CASE( - at::ScalarType::Half, - [&]() { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementC = cutlass::half_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::half_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - }) - AT_DISPATCH_CASE( - at::ScalarType::BFloat16, - [&]() { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::bfloat16_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - })); - - auto tensor_d_sizes = input_sizes; - tensor_d_sizes.back() = weight.size(0); + // Dispatch to appropriate kernel template. + dispatch_on_tensor_a_scale_and_tensor_b_scale( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + + // Reshape and return output tensor. + auto tensor_d_sizes = xq_sizes; + tensor_d_sizes.back() = wq.size(0); return tensor_d.reshape(tensor_d_sizes); #else TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 3dd7971525..eb31cba619 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -43,7 +43,7 @@ def sparsify_( apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ) -> torch.nn.Module: - """Convert the weight of linear modules in the model with `apply_tensor_subclass` + """Convert the weight of linear modules in the model with `apply_tensor_subclass`. This function is essentially the same as quantize, put for sparsity subclasses. Currently, we support three options for sparsity: @@ -54,26 +54,26 @@ def sparsify_( Args: model (torch.nn.Module): input model apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. 
affine quantized tensor instance) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on - the weight of the module + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module - Example:: - import torch - import torch.nn as nn - from torchao.sparsity import sparsify_ + **Example:** + :: + import torch + import torch.nn as nn + from torchao.sparsity import sparsify_ - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) + def filter_fn(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Linear) - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) + m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - # for 2:4 sparsity - from torchao.sparse_api import semi_sparse_weight - m = sparsify_(m, semi_sparse_weight(), filter_fn) + # for 2:4 sparsity + from torchao.sparse_api import semi_sparse_weight + m = sparsify_(m, semi_sparse_weight(), filter_fn) - # for int8 dynamic quantization + 2:4 sparsity - from torchao.dtypes import SemiSparseLayout - m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) + # for int8 dynamic quantization + 2:4 sparsity + from torchao.dtypes import SemiSparseLayout + m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model,
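For reference, a runnable end-to-end sketch of the ``sparsify_`` flow documented above (a sketch only, assuming a CUDA device with 2:4 sparse kernel support and that the public entry points are importable from ``torchao.sparsity`` as listed in the updated API reference; the weight "pruning" below is a stand-in mask rather than a real pruner)::

    import torch
    import torch.nn as nn
    from torchao.sparsity import sparsify_, semi_sparse_weight

    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear)

    m = nn.Sequential(nn.Linear(128, 1024), nn.Linear(1024, 128)).half().cuda()

    # Weights must already be pruned to the 2:4 pattern before the handoff;
    # zero out two of every four elements as a stand-in for a real pruner.
    with torch.no_grad():
        for mod in m.modules():
            if isinstance(mod, nn.Linear):
                mask = torch.tensor([1, 1, 0, 0], dtype=mod.weight.dtype,
                                    device=mod.weight.device)
                mod.weight.mul_(mask.tile(mod.weight.shape[0],
                                          mod.weight.shape[1] // 4))

    # Swap the dense weights for 2:4 semi-structured sparse weights.
    sparsify_(m, semi_sparse_weight(), filter_fn)

    x = torch.rand(64, 128).half().cuda()
    with torch.inference_mode():
        y = m(x)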