diff --git a/.github/workflows/float8nocompile_test.yaml b/.github/workflows/float8nocompile_test.yaml deleted file mode 100644 index b8707c148e..0000000000 --- a/.github/workflows/float8nocompile_test.yaml +++ /dev/null @@ -1,53 +0,0 @@ -name: Run Float8nocompile Tests - -on: - push: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - pull_request: - branches: - - main - - 'gh/**' - paths: - - 'torchao/prototype/float8nocompile/**' - -concurrency: - group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} - cancel-in-progress: true - -env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - -# jobs: -# test: -# strategy: -# fail-fast: false -# matrix: -# include: -# - name: H100 -# runs-on: linux.aws.h100 -# torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124' -# gpu-arch-type: "cuda" -# gpu-arch-version: "12.4" - -# uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main -# with: -# timeout: 300 -# runner: ${{ matrix.runs-on }} -# gpu-arch-type: ${{ matrix.gpu-arch-type }} -# gpu-arch-version: ${{ matrix.gpu-arch-version }} -# submodules: recursive -# script: | -# conda create -n venv python=3.9 -y -# conda activate venv -# export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH -# python -m pip install --upgrade pip -# pip install ${{ matrix.torch-spec }} -# pip install -r dev-requirements.txt -# pip install . -# cd torchao/prototype/float8nocompile -# pytest kernels/ --verbose -s -# pytest test/train_test.py --verbose -s diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 8d274b62e7..2187eed8e3 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -37,10 +37,8 @@ jobs: # of torch and torchao, which we do not want to use pip install executorch pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall - pip install numpy - pip install pytest - pip install parameterized - USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install . + pip install -r dev-requirements.txt + USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . 
- name: Run python tests run: | conda activate venv @@ -99,11 +97,8 @@ jobs: python -c "import torch; print(torch.__version__)" - name: Install requirements run: | - pip install cmake - pip install parameterized - pip install pyyaml - pip install numpy - pip install importlib-metadata + pip install -r dev-requirements.txt + pip install pyyaml importlib-metadata - name: Print pip freeze run: | pip freeze diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 227cb90948..4394d0208b 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -50,3 +50,31 @@ model_params: # device: "cpu" # model_type: "linear" # enable_profiler: true # Enable profiling for this model + + - name: "bf16_rms_norm_linear_activation" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "rms_norm_linear_activation" + enable_profiler: true + enable_memory_profile: true + + - name: "bf16_transformer_block" + matrix_shapes: + - name: "custom" + shapes: [ + [2048, 4096, 1024], # For transformer_block, k is the hidden dimension + ] + high_precision_dtype: "torch.bfloat16" + use_torch_compile: true + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "transformer_block" + enable_profiler: true + enable_memory_profile: true diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 14f226bd7e..46f6a74685 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -17,8 +17,11 @@ Float8DynamicActivationFloat8SemiSparseWeightConfig, Int4WeightOnlyConfig, LNLinearSigmoid, + RMSNorm, + RMSNormLinearActivation, SemiSparseWeightConfig, ToyLinearModel, + TransformerBlock, clean_caches, create_model_and_input, generate_results_csv, @@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self): torch.all((out >= 0) & (out <= 1)) ) # Check sigmoid output range + def test_rms_norm(self): + # Test RMSNorm + rms_norm = RMSNorm(dim=64) + x = torch.randn(16, 64) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + # Test with different eps + rms_norm = RMSNorm(dim=64, eps=1e-5) + out = rms_norm(x) + self.assertEqual(out.shape, (16, 64)) + + def test_rms_norm_linear_activation(self): + # Test with default GELU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertEqual(out.dtype, torch.float32) + + # Test with ReLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertTrue(torch.all(out >= 0)) # Check ReLU output range + + # Test with SiLU activation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu") + out = model(x) + self.assertEqual(out.shape, (16, 32)) + + # Test with invalid activation + with self.assertRaises(ValueError): + RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid") + + def test_transformer_block(self): + # Test with default parameters + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) # [batch_size, 
seq_len, hidden_dim] + out = model(x) + self.assertEqual(out.shape, (16, 16, 64)) + self.assertEqual(out.dtype, torch.float32) + + # Test with different parameters + model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32) + x = torch.randn(8, 32, 128) + out = model(x) + self.assertEqual(out.shape, (8, 32, 128)) + + # Test with different head dimensions + model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32) + x = torch.randn(4, 8, 96) + out = model(x) + self.assertEqual(out.shape, (4, 8, 96)) + def test_create_model_and_input(self): m, k, n = 16, 64, 32 model, input_data = create_model_and_input( @@ -186,6 +244,63 @@ def test_create_model_and_input(self): self.assertIsInstance(model, LNLinearSigmoid) self.assertEqual(input_data.shape, (m, k)) + # Test RMSNormLinearActivation + model, input_data = create_model_and_input( + model_type="rms_norm_linear_activation", + m=m, + k=k, + n=n, + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, RMSNormLinearActivation) + self.assertEqual(input_data.shape, (m, k)) + + # Test TransformerBlock + model, input_data = create_model_and_input( + model_type="transformer_block", + m=m, + k=k, + n=n, # n is not used for transformer_block + high_precision_dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, TransformerBlock) + self.assertEqual(input_data.shape, (m, 16, k)) # [batch_size, seq_len, hidden_dim] + + def test_quantization_on_models(self): + # Test quantization on RMSNormLinearActivation + model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 32)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + + # Test quantization on TransformerBlock + model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32) + x = torch.randn(16, 16, 64) + + # Test with Int8WeightOnlyConfig + config = string_to_config(quantization="int8wo", sparsity=None) + if config is not None: + # Skip quantization test if torchao.quantization.quantize is not available + try: + from torchao.quantization import quantize + quantized_model = quantize(model, config) + out = quantized_model(x) + self.assertEqual(out.shape, (16, 16, 64)) + except ImportError: + print("Skipping quantization test: torchao.quantization.quantize not available") + def test_generate_results_csv(self): results = [ BenchmarkResult( diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 677f66ac75..9e978f70fa 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -383,6 +383,108 @@ def forward(self, x): return x +class RMSNorm(torch.nn.Module): + def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + + def forward(self, x): + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + + +class RMSNormLinearActivation(torch.nn.Module): + def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, 
activation="gelu"): + super().__init__() + self.rms_norm = RMSNorm(fc_dim1, dtype=dtype) + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) + + if activation == "gelu": + self.activation = torch.nn.GELU() + elif activation == "relu": + self.activation = torch.nn.ReLU() + elif activation == "silu": + self.activation = torch.nn.SiLU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward(self, x): + x = self.rms_norm(x) + x = self.fc(x) + x = self.activation(x) + return x + + +class TransformerBlock(torch.nn.Module): + def __init__(self, hidden_dim, num_heads=8, mlp_ratio=4, dtype=torch.bfloat16): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + + # Self-attention + self.qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim, bias=False).to(dtype) + self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False).to(dtype) + + # MLP + self.mlp_ratio = mlp_ratio + self.mlp_hidden_dim = int(hidden_dim * mlp_ratio) + self.mlp_fc1 = torch.nn.Linear(hidden_dim, self.mlp_hidden_dim, bias=False).to(dtype) + self.mlp_fc2 = torch.nn.Linear(self.mlp_hidden_dim, hidden_dim, bias=False).to(dtype) + + # Layer norms + self.norm1 = RMSNorm(hidden_dim, dtype=dtype) + self.norm2 = RMSNorm(hidden_dim, dtype=dtype) + + # Activation + self.activation = torch.nn.GELU() + + def forward(self, x): + batch_size, seq_len, _ = x.shape + + # Self-attention + residual = x + x = self.norm1(x) + + # Reshape qkv projection for better memory layout + qkv = self.qkv(x) # [batch_size, seq_len, 3 * hidden_dim] + qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim] + q, k, v = qkv # Each has shape [batch_size, num_heads, seq_len, head_dim] + + # Scaled dot-product attention with proper reshaping + # Reshape for better memory layout and avoid broadcasting issues + q = q.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + k = k.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + v = v.reshape(batch_size * self.num_heads, seq_len, self.head_dim) + + # Compute attention scores + attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_dim ** 0.5)) + attn = torch.softmax(attn, dim=-1) + + # Apply attention to values + x = attn @ v # [batch_size * num_heads, seq_len, head_dim] + + # Reshape back to original dimensions + x = x.reshape(batch_size, self.num_heads, seq_len, self.head_dim) + x = x.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_dim) + + # Project back to hidden dimension + x = self.proj(x) + x = residual + x + + # MLP + residual = x + x = self.norm2(x) + x = self.mlp_fc1(x) + x = self.activation(x) + x = self.mlp_fc2(x) + x = residual + x + + return x + + def string_to_config( quantization: Optional[str], sparsity: Optional[str], **kwargs ) -> AOBaseConfig: @@ -576,6 +678,14 @@ def create_model_and_input( elif model_type == "ln_linear_sigmoid": model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "rms_norm_linear_activation": + model = RMSNormLinearActivation(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + elif model_type == "transformer_block": + # For transformer block, k is the hidden dimension + model = TransformerBlock(k, num_heads=8, mlp_ratio=4, dtype=high_precision_dtype).to(device) + # Input shape for 
transformer is [batch_size, seq_len, hidden_dim] + input_data = torch.randn(m, 16, k, device=device, dtype=high_precision_dtype) else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data diff --git a/dev-requirements.txt b/dev-requirements.txt index f5b1599ffa..1982d76795 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -26,6 +26,9 @@ importlib_metadata # Custom CUDA Extensions ninja +# CPU kernels +cmake<4.0.0,>=3.19.0 + # Linting ruff==0.6.8 pre-commit diff --git a/scripts/clean_release_notes.py b/scripts/clean_release_notes.py index 2caef0735b..92ce5996cc 100644 --- a/scripts/clean_release_notes.py +++ b/scripts/clean_release_notes.py @@ -223,7 +223,7 @@ def format_commit(commit_line: str) -> str: After: * Commit title (https://github.com/pytorch/ao/pull/123) """ # Remove author, put PR link in parentheses - commit_line = re.sub(" by @.* in (.*)", r" (\\g<1>)", commit_line) + commit_line = re.sub(" by @.* in (.*)", r" (\g<1>)", commit_line) # Capitalize first letter commit_line = commit_line.lstrip("* ") commit_line = "* " + commit_line[0].upper() + commit_line[1:] diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index a67f7775b1..0ebc356114 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -38,6 +38,7 @@ @pytest.mark.skip("skipping for now, see comments below") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize( "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, @@ -89,6 +90,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): TEST_CONFIGS, ) @skip_if_rocm("ROCm enablement in progress") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3c29028898..fcd4969bbf 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -133,6 +133,18 @@ def forward(self, x): return x +class M4(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float) + + def example_inputs(self): + return (torch.randn(1, 512).to(torch.float),) + + def forward(self, x): + return self.linear(x) + + class ModelWithLinearBias(torch.nn.Module): def __init__(self): super().__init__() @@ -1389,6 +1401,65 @@ def test_qat_linear_bias(self): example_inputs = m.example_inputs() m(*example_inputs) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_fake_quantize_per_token_vs_convert(self): + """ + Test that the following produce the exact same numerics: + 1. FakeQuantizer with asymmetric per_token config + 2. 
torchao.quantization.utils.per_token_dynamic_quant + """ + from torchao.quantization.utils import per_token_dynamic_quant + + torch.manual_seed(self.SEED) + x = torch.randn(1, 235, 2048) + config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + fake_quantizer = FakeQuantizer(config) + fake_quantizer_out = fake_quantizer(x) + baseline_out = per_token_dynamic_quant(x) + torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_8da4w_prepare_vs_convert(self): + """ + Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces + numerics that match exactly over N trials. + """ + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.utils import compute_error + + num_trials = 1000 + group_size = 16 + non_inf_sqnr = [] + + for seed in range(self.SEED, self.SEED + num_trials): + torch.manual_seed(seed) + m = M4() + torch.manual_seed(seed) + x = m.example_inputs() + + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + prepared = quantizer.prepare(m) + prepared_out = prepared(*x) + converted = quantizer.convert(prepared) + converted_out = converted(*x) + sqnr = compute_error(prepared_out, converted_out).item() + if sqnr != float("inf"): + non_inf_sqnr.append(sqnr) + + avg_sqnr = ( + sum(non_inf_sqnr) / len(non_inf_sqnr) if len(non_inf_sqnr) > 0 else -1 + ) + fail_message = "%s/%s trials did not match exactly, average sqnr = %s" % ( + len(non_inf_sqnr), + num_trials, + avg_sqnr, + ) + self.assertEqual(len(non_inf_sqnr), 0, fail_message) + if __name__ == "__main__": unittest.main() diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 29339bba8c..4b761ad725 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import torch +# TODO: delete these ops + def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 531e1ba7e6..26f6494220 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -21,6 +21,7 @@ // // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory +// - Added proper architecture check at both host and device level // @@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream, static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); assert(M_Global % 256 == 0); assert(K_Global % 64 == 0); - assert(N_Global>0); + assert(N_Global > 0); + + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. 
" + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && std::is_same::value) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } // Work around to support more N shapes: size_t N_PowerOf2; @@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream, if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - // Check GPU Compute Capability - int device, major, minor; - CHECK_CUDA(cudaGetDevice(&device)); - CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - const bool is_sm75_gpu = (major == 7) && (minor == 5); - if (is_sm75_gpu && std::is_same::value) - TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75"); - if ((major < 7) || (major == 7 && minor < 5)) - TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n"); - if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. if (Split_K == 1) { @@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } @@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream, case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2); + TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); } Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } @@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda( torch::Tensor _scales, int64_t splitK=1) { + // Check GPU Compute Capability before proceeding + int device, major, minor; + CHECK_CUDA(cudaGetDevice(&device)); + CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + // Early exit with error for unsupported architectures + if ((major < 7) || (major == 7 && minor < 5)) { + TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. 
" + "Your current device has SM", major, minor, " which is not supported."); + } + + const bool is_sm75_gpu = (major == 7) && (minor == 5); + if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) { + TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); + } + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index d4be92b227..096bdc0d7f 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -51,17 +51,14 @@ * B: col major, FP16 * C: col major, FP16 */ - template +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, int Split_K) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required."); - // __trap(); // fails at runtime instead of compile time - #endif #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); @@ -233,3 +230,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, } } } +#else +// Stub implementation for older architectures +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ +// NOOP, should never actually be called +} +#endif diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index f05e6b392f..e6b2a6aff0 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -40,6 +40,7 @@ include_directories(${TORCHAO_INCLUDE_DIRS}) if(TORCHAO_BUILD_CPU_AARCH64) message(STATUS "Building with cpu/aarch64") add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) # Defines torchao_kernels_aarch64 add_subdirectory(kernels/cpu/aarch64) diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index 3cca338cbf..f38794d4a8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -19,7 +19,7 @@ if (TORCHAO_BUILD_CPU_AARCH64) # intelligence (AI) workloads tailored for Arm® CPUs. 
FetchContent_Declare(kleidiai GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.2.0) + GIT_TAG v1.5.0) FetchContent_MakeAvailable(kleidiai) target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 2e8d0aa453..aa338fc165 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -14,9 +14,14 @@ #include #include +#include + +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT +#include #include #include -#include +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM #include @@ -60,27 +65,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -template -size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { +size_t packed_activations_size( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_packing = get_lhs_packing(); return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr); } -template -void prepare_activation_data( - void* activation_data, +size_t packed_activations_offset( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)group_size; // unused + (void)has_weight_zeros; // unused + auto lhs_pack = get_lhs_packing(); + return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr); +} + +void pack_activations( + void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { + bool has_weight_zeros, + int mr, + int kr, + int sr) { (void)group_size; // unused (void)has_weight_zeros; // unused auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack( m, k, @@ -90,33 +115,62 @@ void prepare_activation_data( /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), - activation_data); + packed_activations); } -template -size_t weight_data_size( +size_t packed_weights_size( int n, int k, int group_size, + int weight_nbit, bool has_weight_zeros, - bool has_bias) { + bool has_bias, + int nr, + int kr, + int sr) { + (void)weight_nbit; // unused (void)has_weight_zeros; // unused (void)has_bias; // unused auto rhs_pack = get_rhs_packing(); return rhs_pack.get_rhs_packed_size( - n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); + internal::adjust_n(n), + k, + nr, + kr, + sr, + group_size, + kai_datatype::kai_dt_bf16); +} + +size_t packed_weights_offset( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr) { + (void)has_weight_zeros; // unused + (void)has_bias; // unused + auto rhs_pack = get_rhs_packing(); + return rhs_pack.get_rhs_packed_offset( + n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); } -template -void prepare_weight_data( - void* weight_data, +void pack_weights( + void* packed_weights, int n, int k, int group_size, const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { + const float* bias, + int nr, + int kr, + int sr) { if (group_size % 32 != 0) { throw std::runtime_error( "Group size must be a multiple of 32, but got group_size=" + @@ -187,7 +241,7 @@ 
void prepare_weight_data( reinterpret_cast(weight_scales_bf16_padded.data()), /*scale_stride=*/sizeof(uint16_t) * (internal::roundup(k, group_size) / group_size), - /*rhs_packed=*/weight_data, + /*rhs_packed=*/packed_weights, /*extra_bytes=*/0, /*qparams=*/&qparams); } @@ -220,8 +274,8 @@ size_t get_preferred_alignement() { int n, \ int k, \ int group_size, \ - const void* weight_data, \ - const void* activation_data, \ + const void* packed_weights, \ + const void* packed_activations, \ float clamp_min, \ float clamp_max, \ bool has_weight_zeros, \ @@ -235,11 +289,11 @@ size_t get_preferred_alignement() { } \ get_ukernel().run_matmul( \ m, \ - internal::adjust_n(n), \ + n, \ k, \ group_size, \ - activation_data, \ - weight_data, \ + packed_activations, \ + packed_weights, \ output, \ /*dst_stride_row=*/output_m_stride * sizeof(float), \ /*dst_stride_col=*/sizeof(float), \ @@ -248,10 +302,14 @@ size_t get_preferred_alignement() { } \ } +#ifdef TORCHAO_ENABLE_ARM_NEON_DOT DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); DEFINE_KERNEL_STRUCT( matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod); +#endif // TORCHAO_ENABLE_ARM_NEON_DOT #ifdef TORCHAO_ENABLE_ARM_I8MM DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm); diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index 4ca9cef54d..95ecb79dc0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -49,15 +49,21 @@ inline size_t packed_activations_offset( return (m_idx / mr) * packed_activations_size_mr_rows; } -template +template void pack_activations( void* packed_activations, int m, int k, int group_size, const float* activations, - bool has_weight_zeros) { - activation_packing::pack_activations( + bool has_weight_zeros, + int mr, + int kr, + int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused + activation_packing::pack_activations( packed_activations, m, k, group_size, activations, has_weight_zeros); } @@ -93,7 +99,7 @@ inline size_t packed_weights_offset( return (n_idx / nr) * packed_weights_size_nr_cols; } -template +template void pack_weights( void* packed_weights, int n, @@ -102,8 +108,14 @@ void pack_weights( const int8_t* weight_qvals, const float* weight_scales, const int8_t* weight_zeros, - const float* bias) { - weight_packing::pack_weights( + const float* bias, + int nr, + int kr, + int sr) { + (void)nr; // unused + (void)kr; // unused + (void)sr; // unused + weight_packing::pack_weights( packed_weights, n, k, @@ -245,7 +257,7 @@ void kernel_1x4x16_f32_neondot( has_clamp); } -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -260,10 +272,11 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp = false float clamp_min, float clamp_max, - bool has_weight_zeros, + bool has_weight_zeros_, bool has_bias, 
bool has_clamp) { - kernel::kernel_1x8x16_f32_neondot( + (void)has_weight_zeros_; // unused + kernel::kernel_1x8x16_f32_neondot( output, output_m_stride, m, @@ -274,7 +287,6 @@ void kernel_1x8x16_f32_neondot( packed_activations, clamp_min, clamp_max, - has_weight_zeros, has_bias, has_clamp); } diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h index 81f6e6b023..7a53c7302c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h @@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) { // Roughly inspired by // https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads -template +template void kernel_1x8x16_f32_neondot( // Outputs float32_t* output, @@ -73,7 +73,6 @@ void kernel_1x8x16_f32_neondot( // Ignored if has_clamp is false float clamp_min, float clamp_max, - bool has_weight_zeros, bool has_bias, bool has_clamp) { assert(k % group_size == 0); @@ -267,7 +266,7 @@ void kernel_1x8x16_f32_neondot( int32x4_t term1_4567 = vmulq_n_s32(weight_qvals_sum, activation_zero); - if (has_weight_zeros) { + if constexpr (has_weight_zeros) { // Compute term2 and term3 int32_t activation_qvals_sum = *((int32_t*)activation_ptr); diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h deleted file mode 100644 index cd816dba46..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -// TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h -// It exists now to prevent breaking existing code in the interim. 
- -#pragma once - -#if defined(__aarch64__) || defined(__ARM_NEON) - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::linear { -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 1, - /*kr*/ 32, - /*sr*/ 1); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x1x32_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -inline size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - 
has_weight_zeros, - has_bias, - /*nr*/ 4, - /*kr*/ 16, - /*sr*/ 2); -} - -template -inline void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x4x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot - -namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot { - -inline size_t -activation_data_size(int m, int k, int group_size, bool has_weight_zeros) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - packed_activations_size( - m, - k, - group_size, - has_weight_zeros, - /*mr*/ 1, - /*kr*/ 16, - /*sr*/ 2); -} - -inline void prepare_activation_data( - void* activation_data, - // Inputs - int m, - int k, - // Ignored if has_weight_zeros = false - int group_size, - const float* activations, - bool has_weight_zeros) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_activations( - activation_data, m, k, group_size, activations, has_weight_zeros); -} - -template -size_t weight_data_size( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight::packed_weights_size( - n, - k, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - /*nr*/ 8, - /*kr*/ 16, - /*sr*/ 2); -} - -template -void prepare_weight_data( - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - pack_weights( - weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -template -void kernel( - // Outputs - float32_t* output, - // Inputs - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight:: - kernel_1x8x16_f32_neondot( - output, - output_m_stride, - m, - n, - k, - group_size, - weight_data, - activation_data, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); -} - -} // namespace - // channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot - -} // 
namespace torchao::kernels::cpu::aarch64::linear - -#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h new file mode 100644 index 0000000000..b83c28143f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -0,0 +1,384 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { + +namespace { +/* +This function loads one int8x16_t value from a and 16 int8x16_t values from b (one per row of b). +- subl to subtract a_zero_point from a, to get a_low, a_high (int16x8_t each) +- 4 int32x4_t accumulators hold the 16 output columns +- for i in [0, 8): + - load b[i] + - subl to subtract b_zero_point from b, to get b_low, b_high + - smlal_lane to multiply a_low[i] against b_low_low, b_low_high, b_high_low and b_high_high, + accumulating into the 4 int32x4_t partial sums +- for i in [8, 16): + - same as above, using a_high[i - 8] in place of a_low[i] +Possibly better to transpose the 16x16 block of b and use dotprod. Left for future. 
+*/ + +template +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const int16x4_t& a_vec, + const int8x16_t& b_vec, + const int8x16_t& b_zero_point_vec, + int32x4_t (&partial_sums)[4]) { + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + partial_sums[0] = + vmlal_lane_s16(partial_sums[0], vget_low_s16(b_vec_low), a_vec, lane); + partial_sums[1] = + vmlal_lane_s16(partial_sums[1], vget_high_s16(b_vec_low), a_vec, lane); + partial_sums[2] = + vmlal_lane_s16(partial_sums[2], vget_low_s16(b_vec_high), a_vec, lane); + partial_sums[3] = + vmlal_lane_s16(partial_sums[3], vget_high_s16(b_vec_high), a_vec, lane); +} + +void block_mul_1x16x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + const int8_t a_zero_point, + const int8_t* b_zero_point, + int32x4_t (&partial_sums)[4]) { + int8x16_t a_vec = vld1q_s8(a); + int8x8_t a_zero_point_vec = vdup_n_s8(a_zero_point); + int8x16_t b_zero_point_vec = vld1q_s8(b_zero_point); + int16x8_t a_vec_low = vsubl_s8(vget_low_s8(a_vec), a_zero_point_vec); + int16x8_t a_vec_high = vsubl_s8(vget_high_s8(a_vec), a_zero_point_vec); + + int8x16_t b_vec = vld1q_s8(b + 0 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 1 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 2 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 3 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 4 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 5 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 6 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 7 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); + + // Second set of 8 channels + b_vec = vld1q_s8(b + 8 * ldb); + block_mul_1x16x1<0>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 9 * ldb); + block_mul_1x16x1<1>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 10 * ldb); + block_mul_1x16x1<2>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 11 * ldb); + block_mul_1x16x1<3>( + vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 12 * ldb); + block_mul_1x16x1<0>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 13 * ldb); + block_mul_1x16x1<1>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 14 * ldb); + block_mul_1x16x1<2>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); + b_vec = vld1q_s8(b + 15 * ldb); + block_mul_1x16x1<3>( + vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); +} + +TORCHAO_ALWAYS_INLINE void dequantize_1x16_int32_t( + const int32x4_t (&sums)[4], + const float* lhs_scales, + const float* rhs_scales, + float32x4_t (&outputs)[4]) { + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + 
vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + float32x4_t scales_89ab = + vmulq_n_f32(vld1q_f32(rhs_scales + 8), lhs_scales[0]); + float32x4_t scales_cdef = + vmulq_n_f32(vld1q_f32(rhs_scales + 12), lhs_scales[0]); + + outputs[0] = vmulq_f32(vcvtq_f32_s32(sums[0]), scales_0123); + outputs[1] = vmulq_f32(vcvtq_f32_s32(sums[1]), scales_4567); + outputs[2] = vmulq_f32(vcvtq_f32_s32(sums[2]), scales_89ab); + outputs[3] = vmulq_f32(vcvtq_f32_s32(sums[3]), scales_cdef); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Implements quantized matrix multiplication for 8-bit channelwise + * quantized matrices + * + * This specialized implementation handles the case where: + * - Both LHS and RHS have zero points (true, true) + * - Neither LHS nor RHS are transposed (false, false) + * + * The function performs a quantized matrix multiplication C = A * B where: + * - A is an m×k matrix (LHS) + * - B is a k×n matrix (RHS) + * - C is an m×n matrix (output) + * + * The implementation uses NEON intrinsics for vectorized computation and + * processes data in blocks of 16×16 for optimal performance on ARM + * architecture. + * + * @param m Number of rows in LHS and output + * @param n Number of columns in RHS and output + * @param k Number of columns in LHS and rows in RHS + * @param lhs Pointer to LHS matrix data (int8_t) + * @param lhs_stride_m Stride between rows of LHS + * @param rhs Pointer to RHS matrix data (int8_t) + * @param rhs_stride_n Stride between rows of RHS + * @param output Pointer to output matrix (float32_t) + * @param out_stride_m Stride between rows of output + * @param lhs_zero_points Zero points for LHS quantization (per-channel) + * @param rhs_zero_points Zero points for RHS quantization (per-channel) + * @param lhs_scales Scales for LHS quantization (per-channel) + * @param rhs_scales Scales for RHS quantization (per-channel) + * @param lhs_qparams_stride Stride for LHS quantization parameters + * @param rhs_qparams_stride Stride for RHS quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + 
utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected + constexpr int nr = 16; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // don't need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_idx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 2) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set lhs_ptr to the start of the lhs qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx; + int32x4_t int32_sums[nr / 4] = {vdupq_n_s32(0)}; + + // Loop over k in blocks of kr + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + lhs_zero_points[m_idx], + rhs_zero_points + n_idx, + int32_sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + int8x16_t b_zero_point_vec = vld1q_s8(rhs_zero_points + n_idx); + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int16_t a_val = static_cast(lhs_ptr[ki]) - + static_cast(lhs_zero_points[m_idx]); + int8x16_t b_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + int16x8_t b_vec_low = + vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); + int16x8_t b_vec_high = + vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); + int32_sums[0] = + vmlal_n_s16(int32_sums[0], vget_low_s16(b_vec_low), a_val); + int32_sums[1] = + vmlal_n_s16(int32_sums[1], vget_high_s16(b_vec_low), a_val); + int32_sums[2] = + vmlal_n_s16(int32_sums[2], vget_low_s16(b_vec_high), a_val); + int32_sums[3] = + vmlal_n_s16(int32_sums[3], vget_high_s16(b_vec_high), a_val); + } + + float32x4_t res[4]; + dequantize_1x16_int32_t( + int32_sums, lhs_scales + m_idx, rhs_scales + n_idx, res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res[0]); + vst1q_f32(store_loc + 4, res[1]); + vst1q_f32(store_loc + 8, res[2]); + vst1q_f32(store_loc + 12, res[3]); + } // n_idx + } // m_idx + } +}; + +} // namespace + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + 
torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h new file mode 100644 index 0000000000..123b7723e4 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -0,0 +1,336 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { + +/* +This function loads an int8x16_t value from a and 8 int8x16_t values from b, and +computes 8 dot products, resulting in 8 int32x4_t values. +Furthermore, the int8x16_t value from a is reduced via summing, resulting in +int32_t row_sum_a. Similarly, the int8x16_t values from b are reduced via summing, +resulting in int32_t row_sum_b[8]. +*/ +TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16( + const int8_t* a, + const int8_t* b, + const size_t ldb, + int32x4_t (&partial_sums)[8], + int32_t& row_sum_a, + int32_t (&row_sum_b)[8]) { + int8x16_t a_vec = vld1q_s8(a); + row_sum_a = row_sum_a + vaddlvq_s8(a_vec); + +// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loop does not quite +// get optimized by moving all the loads up in the unrolled loop. Just hoping the +// OOO machine will take care of things. Later, replace this with macros so as to +// deconstruct the loop and do manual optimization. Or just write assembly. 
+#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec); + partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec); + } +} + +TORCHAO_ALWAYS_INLINE static void reduce_1x8_int32x4_t_sums( + const int32x4_t (&partial_sums)[8], + int32_t (&sums)[8]) { +#pragma unroll(8) + for (int i = 0; i < 8; ++i) { + sums[i] = vaddvq_s32(partial_sums[i]); + } +} + +TORCHAO_ALWAYS_INLINE static void dequantize_1x8_int32_t( + const int32_t (&sums)[8], + int32_t& row_sum_lhs, + int32_t (&row_sum_rhs)[8], + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int32_t k, + float32x4x2_t& outputs) { + int32x4_t vec_sum_0123 = vld1q_s32(sums); + int32x4_t vec_sum_4567 = vld1q_s32(sums + 4); + + int32x4_t row_sum_rhs_x_lhs_zp_0123 = + vmulq_n_s32(vld1q_s32(row_sum_rhs), (int32_t)lhs_zero_points[0]); + int32x4_t row_sum_rhs_x_lhs_zp_4567 = + vmulq_n_s32(vld1q_s32(row_sum_rhs + 4), (int32_t)lhs_zero_points[0]); + + // Extract rhs zero point in int8x8_t and convert to int32x4_t + int16x8_t rhs_zero_points_vec_01234567 = vmovl_s8(vld1_s8(rhs_zero_points)); + int32x4_t rhs_zero_points_vec_0123 = + vmovl_s16(vget_low_s16(rhs_zero_points_vec_01234567)); + int32x4_t rhs_zero_points_vec_4567 = + vmovl_s16(vget_high_s16(rhs_zero_points_vec_01234567)); + int32x4_t row_sum_lhs_x_rhs_zp_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, row_sum_lhs); + int32x4_t row_sum_lhs_x_rhs_zp_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, row_sum_lhs); + + int32x4_t zp_rhs_x_zp_lhs_0123 = + vmulq_n_s32(rhs_zero_points_vec_0123, k * (int32_t)lhs_zero_points[0]); + int32x4_t zp_rhs_x_zp_lhs_4567 = + vmulq_n_s32(rhs_zero_points_vec_4567, k * (int32_t)lhs_zero_points[0]); + + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_rhs_x_lhs_zp_0123); + vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_lhs_x_rhs_zp_0123); + vec_sum_0123 = vaddq_s32(vec_sum_0123, zp_rhs_x_zp_lhs_0123); + + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_rhs_x_lhs_zp_4567); + vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_lhs_x_rhs_zp_4567); + vec_sum_4567 = vaddq_s32(vec_sum_4567, zp_rhs_x_zp_lhs_4567); + + float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); + float32x4_t scales_4567 = + vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); + + outputs.val[0] = vmulq_f32(vcvtq_f32_s32(vec_sum_0123), scales_0123); + outputs.val[1] = vmulq_f32(vcvtq_f32_s32(vec_sum_4567), scales_4567); +} + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template <> +struct KernelImpl { + /** + * @brief Executes a quantized matrix multiplication with channelwise + * quantization parameters + * + * This function performs matrix multiplication between two 8-bit quantized + * matrices with per-channel quantization parameters. It handles the following + * operations: + * 1. Transposes quantization parameters if they're not contiguous + * 2. Processes the matrices in blocks of 8 columns at a time + * 3. Uses NEON dot product instructions for efficient computation + * 4. 
Handles edge cases for remaining elements + * 5. Dequantizes the results to floating point + * + * @param m Number of rows in the output matrix + * @param n Number of columns in the output matrix + * @param k Number of columns in lhs / rows in rhs + * @param lhs Pointer to the left-hand side matrix (quantized int8) + * @param lhs_stride_m Stride between rows of the lhs matrix + * @param rhs Pointer to the right-hand side matrix (quantized int8) + * @param rhs_stride_n Stride between rows of the rhs matrix. Expects matrix + * to be transposed. Thus of size [n x k] + * @param output Pointer to the output matrix (float32) + * @param out_stride_m Stride between rows of the output matrix + * @param lhs_zero_points Zero points for lhs quantization (per-channel) + * @param rhs_zero_points Zero points for rhs quantization (per-channel) + * @param lhs_scales Scales for lhs quantization (per-channel) + * @param rhs_scales Scales for rhs quantization (per-channel) + * @param lhs_qparams_stride Stride for lhs quantization parameters + * @param rhs_qparams_stride Stride for rhs quantization parameters + */ + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + // If lhs_zero_points and rhs_zero_points are not contiguous, transpose + std::unique_ptr lhs_zero_points_transposed = + std::make_unique(m); + std::unique_ptr lhs_scales_transposed = + std::make_unique(m); + if (lhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + lhs_zero_points, + lhs_scales, + lhs_zero_points_transposed.get(), + lhs_scales_transposed.get(), + m, + lhs_qparams_stride); + lhs_zero_points = lhs_zero_points_transposed.get(); + lhs_scales = lhs_scales_transposed.get(); + } + std::unique_ptr rhs_zero_points_transposed = + std::make_unique(n); + std::unique_ptr rhs_scales_transposed = + std::make_unique(n); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + n, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 8 cols at a time + // Access to partial tiles must be protected:w + constexpr int nr = 8; + constexpr int kr = 16; + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. 
+ // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n; + int32x4_t int32_sums[nr] = {vdupq_n_s32(0)}; + int32_t row_sum_lhs = 0; + int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0}; + int32_t sums[nr]; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x8x16( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + int32_sums, + row_sum_lhs, + row_sum_rhs); + lhs_ptr += kr; + rhs_ptr += kr; + } + + reduce_1x8_int32x4_t_sums(int32_sums, sums); + for (int ki = 0; ki < (k - k_idx); ++ki) { + row_sum_lhs += (int32_t)lhs_ptr[ki]; + } + for (int ni = 0; ni < nr; ++ni) { + for (int ki = 0; ki < (k - k_idx); ++ki) { + sums[ni] += (int32_t)lhs_ptr[ki] * + (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + row_sum_rhs[ni] += (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; + } + } + + float32x4x2_t res; + dequantize_1x8_int32_t( + sums, + row_sum_lhs, + row_sum_rhs, + lhs_zero_points + m_idx, + rhs_zero_points + n_idx, + lhs_scales + m_idx, + rhs_scales + n_idx, + k, + res); + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + float* store_loc = output + m_idx * out_stride_m + n_idx; + vst1q_f32(store_loc, res.val[0]); + vst1q_f32(store_loc + 4, res.val[1]); + } // n_idx + } // m_idx + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal + +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h new file mode 100644 index 0000000000..bdad1b4a47 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -0,0 +1,281 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
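
For reference, both channelwise 8-bit kernels above compute the same affine-dequantized product, out[m][n] = lhs_scale[m] * rhs_scale[n] * sum_k (lhs_q - lhs_zp[m]) * (rhs_q - rhs_zp[n]); the 1x16x16 variant subtracts the zero points before multiplying, while the 1x8x16 variant expands the product and applies the row-sum corrections afterwards, as in dequantize_1x8_int32_t. Below is a minimal scalar sketch of that expanded form (a hypothetical helper for illustration, not part of this patch; rhs is assumed transposed, i.e. stored n x k):

#include <cstdint>

void scalar_reference_qmatmul(
    int m, int n, int k,
    const int8_t* lhs, const int8_t* rhs,              // lhs: m x k, rhs: n x k
    const int8_t* lhs_zeros, const int8_t* rhs_zeros,  // per-row / per-column zero points
    const float* lhs_scales, const float* rhs_scales,
    float* out) {
  for (int mi = 0; mi < m; ++mi) {
    for (int ni = 0; ni < n; ++ni) {
      int32_t acc = 0, row_sum_a = 0, row_sum_b = 0;
      for (int ki = 0; ki < k; ++ki) {
        int32_t a = lhs[mi * k + ki];
        int32_t b = rhs[ni * k + ki];
        acc += a * b;
        row_sum_a += a;
        row_sum_b += b;
      }
      // Zero-point corrections: (a - zp_a)(b - zp_b) summed over k, expanded.
      acc -= static_cast<int32_t>(lhs_zeros[mi]) * row_sum_b;
      acc -= static_cast<int32_t>(rhs_zeros[ni]) * row_sum_a;
      acc += k * static_cast<int32_t>(lhs_zeros[mi]) * static_cast<int32_t>(rhs_zeros[ni]);
      out[mi * n + ni] = static_cast<float>(acc) * lhs_scales[mi] * rhs_scales[ni];
    }
  }
}
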
+ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { + +namespace { + +/* +This function loads float32x4_t value from a, and 16 int8x16_t values from b. +For each int8x16_t of b: +- 4 float32x4 accumulated values +- load 4 a in float32x4_t +- [The following repeats for each of the 4 lanes of a] +- for i in [0, 4]: + - load b[i] in int8x16_t + - subl to subtract b_zero_point from b, to get b_low, b_high + - vmovl to get b_low_low, b_low_high, b_high_low, b_high_high + - vcvtq to convert to float32x4_t, we will have 4 of these. +- for i in [0, 4]: for each of the 4 float32x4_t of b: + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] + - vfmaq_lane_fp32 to multiply a[lane] and b[i] +- By doing the above 4 times (lane=[0-3]), we used all values along k dim of a + and accumulated 4 float32x4_t values +*/ +TORCHAO_ALWAYS_INLINE void block_mul_1x16x1( + const float32_t a, + const int8x16_t& b_vec, + const int8_t b_zero_point, + const float b_scale, + float32x4_t (&partial_sums)[4]) { + int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point); + int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec); + int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec); + float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low))); + float32x4_t b_vec_low_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low))); + float32x4_t b_vec_high_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high))); + float32x4_t b_vec_high_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high))); + b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale); + b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale); + b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale); + b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale); + + partial_sums[0] = vfmaq_n_f32(partial_sums[0], b_vec_low_low, a); + partial_sums[1] = vfmaq_n_f32(partial_sums[1], b_vec_low_high, a); + partial_sums[2] = vfmaq_n_f32(partial_sums[2], b_vec_high_low, a); + partial_sums[3] = vfmaq_n_f32(partial_sums[3], b_vec_high_high, a); +} + +void block_mul_1x16x4( + const float32_t* a, + const int8_t* b, + const size_t ldb, + const int8_t* b_zero_point, + const float* b_scale, + float32x4_t (&partial_sums)[4]) { + #pragma unroll(8) + for (int i = 0; i < 4; i++) { + int8x16_t b_vec = vld1q_s8(b + i * ldb); + block_mul_1x16x1(a[i], b_vec, b_zero_point[i], b_scale[i], partial_sums); + } +} + +} // namespace + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); +}; + +/* +Document param meaning +rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n. +Thus rhs_stride_n is the stride of k dim, that how many bytes aparts elements +in k dim are. 
+*/ +template <> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + std::unique_ptr rhs_zero_points_transposed = std::make_unique(k); + std::unique_ptr rhs_scales_transposed = std::make_unique(k); + if (rhs_qparams_stride > 1) { + utils::transpose_scales_and_zero_points( + rhs_zero_points, + rhs_scales, + rhs_zero_points_transposed.get(), + rhs_scales_transposed.get(), + k, + rhs_qparams_stride); + rhs_zero_points = rhs_zero_points_transposed.get(); + rhs_scales = rhs_scales_transposed.get(); + } + + constexpr int nr = 16; + constexpr int kr = 4; + for (int m_idx = 0; m_idx < m; m_idx++) { + // Loop over 16 cols at a time + // Access to partial tiles must be protected:w + assert(n >= nr); + for (int n_idx = 0; n_idx < n; n_idx += nr) { + // If remaining is < nr, that must mean that (nr - remaining) items + // dont need to be computed. + // In order to avoid out-of-bounds access, we need to rewind n_indx a + // bit + // |-------------------|-------------------| + // 0-------------------8-------------------16 + // 0-------------------8-----10 + // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to + // 8 - (8 - 10) = 2 + int remaining = std::min(n - n_idx, nr); + n_idx = n_idx - (nr - remaining); + // Set activation_ptr to start of activation qvals for row m_idx + const float* lhs_ptr = lhs + m_idx * lhs_stride_m; + const int8_t* rhs_ptr = rhs + n_idx; + float32x4_t sums[nr / 4] = {vdupq_n_f32(0)}; + + // Loop k_idx by group + int k_idx = 0; + for (; (k_idx + kr) <= k; k_idx += kr) { + block_mul_1x16x4( + lhs_ptr, + rhs_ptr, + rhs_stride_n, + rhs_zero_points + k_idx, + rhs_scales + k_idx, + sums); + lhs_ptr += kr; + rhs_ptr += kr * rhs_stride_n; + } + + for (int ki = 0; ki < (k - k_idx); ++ki) { + // For each of the remaining k values + // Load 1 int8_t from lhs + // Load 16 int8_t from rhs + // And multiply + add into the 16 accumulators + // arranged as int32x4_t[4] + int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); + block_mul_1x16x1( + lhs_ptr[ki], + rhs_vec, + rhs_zero_points[k_idx + ki], + rhs_scales[k_idx + ki], + sums); + } + + // Store result + // Because we adjust n_idx, we may end up writing the same location + // twice + // Note that the reason this case is being handled only for this kernel + // and not others in this directory is because only for this kernel + // we support accumulation. + float* store_loc = output + m_idx * out_stride_m + n_idx; + if (remaining < 16) { + // If remaining is < 16, then not all of the 16 accumulators are + // valid. That is not all of float32x4_t[4] are valid. We need to + // find the first valid one, and then store the rest of the + // accumulators in the same order. + // First valid one is at 3 - ((remaining - 1) / 4) because: + // If remaining is say 10 then first 6 are not valid. + // Thus first group of 4 at sums[0] is not valid. + // In the second group of 4, the first 2 are not valid. + // Rest are valid. 
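          // Worked example: with remaining = 10 the first 16 - 10 = 6 lanes
          // overlap columns that were already accumulated on the previous
          // n_idx iteration, so
          //   start_sum_idx              = 3 - (10 - 1) / 4 = 1  (skip sums[0])
          //   invalid_values_in_32x4_reg = 3 - (10 - 1) % 4 = 2  (skip 2 lanes of sums[1])
          // and store_loc is advanced by 4 + 2 = 6 floats, matching the 6
          // invalid leading lanes.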
+ int start_sum_idx = 3 - ((remaining - 1) / 4); + // If remaining is 11, then the sums[1] has 3 valid values + // so 3 - (11 -1) % 4 = 3 - 10 % 4 = 3 - 2 = 1 + // Thus there is 1 invalid value in the first group of 4 + int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4; + store_loc += start_sum_idx * 4; + store_loc += invalid_values_in_32x4_reg; + if (invalid_values_in_32x4_reg > 0) { + for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4; + ++val_idx) { + *store_loc = sums[start_sum_idx][val_idx] + (*store_loc) * beta; + store_loc += 1; + } + start_sum_idx++; + } + for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } else { + for (int out_idx = 0, sum_idx = 0; out_idx < nr; + out_idx += 4, ++sum_idx) { + float32x4_t sum_val = vld1q_f32(store_loc + out_idx); + sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); + vst1q_f32(store_loc + out_idx, sums[sum_idx]); + } + } + } // n_idx + } // m_idx + } +}; + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + rhs_zero_points, + rhs_scales, + beta, + rhs_qparams_stride); +} +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h new file mode 100644 index 0000000000..43f3dd4bce --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h @@ -0,0 +1,95 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// TODO: this file will be deleted and replaced by +// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h +// It exists now to prevent breaking existing code in the interim. 
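
As a sanity model for the fp32-activation kernel above, each output element it produces is out[m][n] = beta * out[m][n] + sum_k lhs[m][k] * rhs_scale[k] * (rhs_q[k][n] - rhs_zp[k]), with B quantized along the input-channel (k) dimension. A scalar sketch of that contract follows (hypothetical helper, for illustration only):

#include <cstdint>

void scalar_reference_fp32_a_q8_b(
    int m, int n, int k,
    const float* lhs, int lhs_stride_m,
    const int8_t* rhs, int rhs_stride_n,                // k x n, row stride rhs_stride_n
    const int8_t* rhs_zeros, const float* rhs_scales,   // one entry per input channel (k)
    float beta, float* out, int out_stride_m) {
  for (int mi = 0; mi < m; ++mi) {
    for (int ni = 0; ni < n; ++ni) {
      float acc = 0.0f;
      for (int ki = 0; ki < k; ++ki) {
        // Dequantize one element of B with its per-input-channel scale/zero point.
        float b = rhs_scales[ki] *
            static_cast<float>(rhs[ki * rhs_stride_n + ni] - rhs_zeros[ki]);
        acc += lhs[mi * lhs_stride_m + ki] * b;
      }
      // beta-accumulate into the existing output, as the NEON kernel does.
      out[mi * out_stride_m + ni] = beta * out[mi * out_stride_m + ni] + acc;
    }
  }
}
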
+ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot + +namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); + +} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal + +namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { + +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float32_t* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride); + +} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#include +#include +#include + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h new file mode 100644 index 0000000000..68ab912705 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h @@ -0,0 +1,70 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
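
The header above only declares the kernel templates; a caller instantiates them with explicit boolean flags, mirroring the unit tests added later in this patch. Below is a hedged sketch of such a call site (buffer names and the include path are placeholders, not taken from this patch):

#include <cstdint>
#include <vector>
// Assumed include path for the header declared above.
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>

void example_qmatmul_call(int m, int n, int k) {
  using namespace torchao::kernels::cpu::aarch64::quantized_matmul;
  std::vector<int8_t> lhs_q(m * k), rhs_q(n * k);  // rhs stored transposed (n x k)
  std::vector<int8_t> lhs_zeros(m), rhs_zeros(n);
  std::vector<float> lhs_scales(m), rhs_scales(n), out(m * n);

  channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::kernel<
      /*a_has_zeros=*/true,
      /*b_has_zeros=*/true,
      /*a_transposed=*/false,
      /*b_transposed=*/true>(
      m, n, k,
      lhs_q.data(), /*lhs_stride_m=*/k,
      rhs_q.data(), /*rhs_stride_n=*/k,
      out.data(), /*out_stride_m=*/n,
      lhs_zeros.data(), rhs_zeros.data(),
      lhs_scales.data(), rhs_scales.data(),
      /*lhs_qparams_stride=*/1, /*rhs_qparams_stride=*/1);
}
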
+ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +namespace torchao::kernels::cpu::aarch64::quantized_matmul { +namespace utils { + +TORCHAO_ALWAYS_INLINE static void transpose_scales_and_zero_points( + const int8_t* zero_points, + const float* scales, + int8_t* zero_points_transposed, + float* scales_transposed, + const int m, + const int stride_m) { + // Process 8 elements at a time using NEON + int i = 0; + for (; i + 8 <= m; i += 8) { + // Load 8 zero points with stride_m + int8x8_t zp = { + zero_points[0 * stride_m], + zero_points[1 * stride_m], + zero_points[2 * stride_m], + zero_points[3 * stride_m], + zero_points[4 * stride_m], + zero_points[5 * stride_m], + zero_points[6 * stride_m], + zero_points[7 * stride_m]}; + zero_points += 8 * stride_m; + // Store contiguously + vst1_s8(zero_points_transposed + i, zp); + + // Load 8 scales with stride_m + float32x4_t scales_lo = { + scales[0 * stride_m], + scales[1 * stride_m], + scales[2 * stride_m], + scales[3 * stride_m]}; + float32x4_t scales_hi = { + scales[4 * stride_m], + scales[5 * stride_m], + scales[6 * stride_m], + scales[7 * stride_m]}; + scales += 8 * stride_m; + // Store contiguously + vst1q_f32(scales_transposed + i, scales_lo); + vst1q_f32(scales_transposed + i + 4, scales_hi); + } + + // Handle remaining elements + for (; i < m; i++) { + zero_points_transposed[i] = zero_points[0]; + scales_transposed[i] = scales[0]; + zero_points += stride_m; + scales += stride_m; + } +} + +} // namespace utils +} // namespace torchao::kernels::cpu::aarch64::quantized_matmul + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp index 65416fdf1d..3460d67fba 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include void torchao::quantization::get_qvals_range( @@ -64,8 +65,6 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8_t zero, int8_t qmin, int8_t qmax) { - assert(size % 8 == 0); - float32_t invScale = 1.0 / (scale + 1e-16); float32x4_t vec_zero = vdupq_n_f32(zero); float32x4_t vec_invScale = vdupq_n_f32(invScale); @@ -78,7 +77,8 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int16x4_t vec_qval_s16_0; int16x4_t vec_qval_s16_1; - for (int i = 0; i < size; i += 8) { + int i = 0; + for (; (i + 8) < size; i += 8) { ////////////////////////////////////// // Quantize first 4 element chunk to int16 ////////////////////////////////////// @@ -112,6 +112,23 @@ void torchao::kernels::cpu::aarch64::quantization::quantize( int8x8_t vec_qval_s8_01 = vqmovn_s16(vec_qval_s16_01); vst1_s8(qvals + i, vec_qval_s8_01); } + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float32_t val = vals[i]; + float32_t qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); } #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt 
b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 5b6ba2ab98..db736d84a3 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -42,6 +42,7 @@ add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 $ if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_ARM_I8MM) @@ -119,6 +120,14 @@ target_link_libraries( dep ) +add_executable(test_qmatmul test_qmatmul.cpp) +target_link_libraries( + test_qmatmul + PRIVATE + GTest::gtest_main + dep +) + include(GoogleTest) gtest_discover_tests(test_quantization) gtest_discover_tests(test_reduction) @@ -127,3 +136,4 @@ gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) gtest_discover_tests(test_embedding) gtest_discover_tests(test_weight_packing) +gtest_discover_tests(test_qmatmul) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 4b2181d7cc..1898e8b535 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -61,3 +61,4 @@ ${CMAKE_OUT}/test_linear ${CMAKE_OUT}/test_valpacking ${CMAKE_OUT}/test_embedding ${CMAKE_OUT}/test_weight_packing +${CMAKE_OUT}/test_qmatmul diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 2e19a524e5..671ee3f0b9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -12,17 +12,23 @@ #include #include #include -#include #include float kTol = 0.0001; -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 1; + constexpr int kr = 32; + constexpr int sr = 1; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -35,48 +41,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, 
sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x1x32_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -88,56 +92,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -150,48 +117,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), + std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = 
test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x4x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -203,69 +168,19 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} - -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - NLessThan4) { - for (int n = 1; n < 4; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } -} - -template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16( int m, int k, int n, - int group_size) { + int group_size, + bool has_bias, + bool has_clamp) { + constexpr int mr = 1; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -278,48 +193,46 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; - std::vector activation_data( - activation_data_size(m, k, group_size, has_weight_zeros)); - prepare_activation_data( - (void*)activation_data.data(), 
+ std::vector packed_activations( + packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr)); + pack_activations( + (void*)packed_activations.data(), m, k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); - std::vector weight_data(weight_data_size( - n, k, group_size, has_weight_zeros, has_bias)); - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - prepare_weight_data( - (void*)weight_data.data(), + std::vector packed_weights(packed_weights_size( + n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)); + pack_weights( + (void*)packed_weights.data(), n, k, group_size, test_case.weight_qvals.data(), test_case.weight_scales.data(), - /*weight_zeros=*/weight_zeros_ptr, - bias_ptr); + has_weight_zeros ? test_case.weight_zeros.data() : nullptr, + has_bias ? test_case.bias.data() : nullptr, + nr, + kr, + sr); std::vector output(m * n); - kernel( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, n, k, group_size, - weight_data.data(), - activation_data.data(), + packed_weights.data(), + packed_activations.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max, has_weight_zeros, @@ -331,70 +244,182 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - Standard) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x1x32) { + constexpr int weight_nbit = 4; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasWeightZeros) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasBias) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/32, + /*has_bias=*/false, + /*has_clamp=*/true); } -TEST( - 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x4x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 4 + for (int n = 1; n < 4; n++) { + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + } } -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - NLessThan8) { +TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, tile_1x8x16) { + constexpr int weight_nbit = 4; + + // Standard + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With weight zeros + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/true>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); + + // With bias + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/true, + /*has_clamp=*/false); + + // With clamp + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/13, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/true); + + // n less than 8 for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16< + weight_nbit, + /*has_weight_zeros=*/false>( + /*m=*/7, + /*k=*/64, + /*n=*/n, + /*group_size=*/16, + /*has_bias=*/false, + /*has_clamp=*/false); } } -template +template void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( int m, int k, int n, int group_size, - bool has_weight_zeros, bool has_bias, bool has_clamp) { constexpr int mr = 1; @@ -424,7 +449,10 @@ void 
test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( k, group_size, test_case.activations.data(), - has_weight_zeros); + has_weight_zeros, + mr, + kr, + sr); // Define equivalent LUT for affine quantization constexpr int lut_size = (1 << weight_nbit); @@ -453,7 +481,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( has_bias ? test_case.bias.data() : nullptr); std::vector output(m * n); - kernel_1x8x16_f32_neondot( + kernel_1x8x16_f32_neondot( output.data(), /*output_m_stride=*/n, m, @@ -476,85 +504,90 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( TEST(test_channelwise_8bit_activation_groupwise_lowbit_weight, LUT) { constexpr int weight_nbit = 4; - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); // has_weight_zeros - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ true>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/true, /*has_bias=*/false, /*has_clamp=*/false); // has_bias - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/true, /*has_clamp=*/false); // has_clamp - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros*/ false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/true); // n less than 8 (nr) for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_lut( + test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< + weight_nbit, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } // Other bitwidths test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 1>( + /*weight_nbit*/ 1, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 2>( + /*weight_nbit*/ 2, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); test_channelwise_8bit_activation_groupwise_lowbit_weight_lut< - /*weight_nbit*/ 3>( + /*weight_nbit*/ 3, + /*has_weight_zeros=*/false>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16, - /*has_weight_zeros=*/false, /*has_bias=*/false, /*has_clamp=*/false); } diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp new file mode 100644 index 0000000000..ff4f915b2d --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp @@ -0,0 +1,600 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +#include +#include +#include +#include + +float kTol = 0.0001; + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k * stride /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + k * stride /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + false> { + static void Run(int m, int k, int n, int stride = 1) { + // TODO: make use of stride for this kernel + auto test_case = + torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case:: + generate(m, k, n, a_has_zeros, a_has_zeros, false, false); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + k /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TransposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +TEST( + 
test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 10); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TransposedBWithZeroPointsOddSizes2Strided2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/3, /*k=*/64, /*n=*/24, 7); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, NoTransposedWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + NoTransposedWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + false /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +class FP32A_QuantizedB_FP32C_Test : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + // Here stride is only applicable to rhs + // and it means that k elements are stride * napart for k x n matrix + // and stride apart for n x k matrix + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = torchao::get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + torchao::test_utils::generate_per_token_quantized_tensor( + k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = torchao::get_random_vector(m * n, -1.0, 1.0); + + 
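    // Note: B is quantized along the input-channel (k) dimension, and the
    // optional qparams stride spreads those parameters out, so rhs_scales and
    // rhs_zeros hold k * stride entries of which every stride-th one is read
    // by the kernel.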
assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + k /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + n * stride /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPoints) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsLargeM) { + generate(4, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 24, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/24, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes3) { + generate(4, 27, 21, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/27, /*n=*/21, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsAlpha) { + generate(1, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsWithStrides) { + stride = 5; + generate(1, 128, 16, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/1, /*k=*/128, /*n=*/16, beta(), *this, stride); +} + +TEST_P(FP32A_QuantizedB_FP32C_Test, BTranposedWithZeroPointsOddSizes2Strides) { + stride = 11; + generate(7, 37, 19, true, false, false, stride); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/7, /*k=*/37, /*n=*/19, beta(), *this, stride); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Test, + ::testing::Values(0.0, 1.0, 
2.69)); + +static void test_8bit_per_token_q_at_k_matmul_attention( + int b, + int s_q, + int s_k, + int h, + int d, + bool transpose = true) { + auto test_case = torchao:: + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case:: + generate(b, s_q, s_k, h, d, transpose); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot; + + size_t q_b_stride = test_case.b_q_stride; + size_t q_h_stride = test_case.h_q_stride; + size_t q_s_q_stride = test_case.s_q_stride; + size_t q_scale_zp_b_stride = test_case.b_q_qparams_stride; + size_t q_scale_zp_h_stride = test_case.h_q_qparams_stride; + size_t q_scale_zp_s_stride = test_case.s_q_qparams_stride; + + size_t k_b_stride = test_case.b_k_stride; + size_t k_h_stride = test_case.h_k_stride; + size_t k_s_k_stride = test_case.s_k_stride; + size_t k_scale_zp_b_stride = test_case.b_k_qparams_stride; + size_t k_scale_zp_h_stride = test_case.h_k_qparams_stride; + size_t k_scale_zp_s_stride = test_case.s_k_qparams_stride; + + std::vector output(b * h * s_q * s_k); + size_t output_b_stride = h * s_q * s_k; + size_t output_h_stride = s_q * s_k; + size_t output_s_q_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_q, + s_k, + d, + test_case.q_qvals.data() + b_idx * q_b_stride + h_idx * q_h_stride, + q_s_q_stride /*lhs_stride_m*/, + test_case.k_qvals.data() + b_idx * k_b_stride + h_idx * k_h_stride, + k_s_k_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_q_stride /*out_stride_n*/, + test_case.q_zeros.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_zeros.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + test_case.q_scales.data() + b_idx * q_scale_zp_b_stride + + h_idx * q_scale_zp_h_stride, + test_case.k_scales.data() + b_idx * k_scale_zp_b_stride + + h_idx * k_scale_zp_h_stride, + q_scale_zp_s_stride /*lhs qparams stride*/, + k_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * h * s_q * s_k; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSk) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3); +} + +TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicNoTransposed) { + test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndHeadDimDiffSqSkNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33, false); +} + +TEST( + test_8bit_per_token_q_at_k_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false); +} + +static void test_fp32_attn_scores_at_v_matmul_attention( + int b, + int s_attn, + int s_v, + int h, + int d, + bool transpose_v = true) { + auto test_case = + torchao::fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case::generate( + b, s_attn, s_v, h, d, 
transpose_v); + + using namespace torchao::kernels::cpu::aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32; + + size_t attn_b_stride = test_case.b_attn_stride; + size_t attn_h_stride = test_case.h_attn_stride; + size_t attn_s_q_stride = test_case.s_attn_stride; + + size_t v_b_stride = test_case.b_v_stride; + size_t v_h_stride = test_case.h_v_stride; + size_t v_s_v_stride = test_case.s_v_stride; + size_t v_scale_zp_b_stride = test_case.b_v_qparams_stride; + size_t v_scale_zp_h_stride = test_case.h_v_qparams_stride; + size_t v_scale_zp_s_stride = test_case.s_v_qparams_stride; + + std::vector output(b * s_attn * h * d); + size_t output_b_stride = s_attn * h * d; + size_t output_s_attn_stride = h * d; + size_t output_h_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + kernel( + s_attn, + d, + s_v, + test_case.attn_scores.data() + b_idx * attn_b_stride + + h_idx * attn_h_stride, + attn_s_q_stride /*lhs_stride_m*/, + test_case.v_qvals.data() + b_idx * v_b_stride + h_idx * v_h_stride, + v_s_v_stride /*rhs_stride_n*/, + output.data() + b_idx * output_b_stride + h_idx * output_h_stride, + output_s_attn_stride /*out_stride_n*/, + test_case.v_zeros.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + test_case.v_scales.data() + b_idx * v_scale_zp_b_stride + + h_idx * v_scale_zp_h_stride, + 0.0 /*beta*/, + v_scale_zp_s_stride /*rhs qparams stride*/); + } + } + + for (int i = 0; i < b * s_attn * h * d; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, Basic) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeHeadsAndSmallHeadDim) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 16, 16, 8, 16, false); +} + +TEST( + test_fp32_attn_scores_at_v_matmul_attention, + PrimeHeadsAndSmallHeadDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 8, 8, 7, 17, false); +} + +TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) { + test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false); +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h index 4720b68fb0..4f96f8bf96 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -84,6 +84,58 @@ inline float get_float_from_bf16(uint16_t bf16) { return f; } +namespace test_utils { +auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false); + +auto generate_per_token_quantized_tensor(int m, int n, bool transposed) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use 
aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, /*vals=*/activations.data() + m_idx * n, /*size=*/n); + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} +} // namespace test_utils + struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { int m; int k; @@ -182,34 +234,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { // weights is k x n (stored in column-major) // Generate activations - auto activations = get_random_vector(m * k, -1.0, 1.0); - auto activation_qvals = std::vector(m * k, 0); - auto activation_scales = std::vector(m, 0); - auto activation_zeros = std::vector(m, 0); - - // Quantize activations with 8-bit asymmetric - // TODO: replace with generic function that does not use aarch64 - // quantize method after we combine with torchao - int qmin, qmax, zero; - float vmin, vmax, scale; - torchao::quantization::get_qvals_range( - qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); - for (int m_idx = 0; m_idx < m; m_idx++) { - torchao::kernels::cpu::aarch64::reduction::find_min_and_max( - vmin, vmax, /*vals=*/activations.data() + m_idx * k, /*size=*/k); - torchao::quantization::get_scale_and_zero( - scale, zero, vmin, vmax, qmin, qmax); - activation_scales[m_idx] = scale; - activation_zeros[m_idx] = zero; - torchao::kernels::cpu::aarch64::quantization::quantize( - /*qvals=*/activation_qvals.data() + m_idx * k, - /*vals=*/activations.data() + m_idx * k, - /*size=*/k, - scale, - zero, - qmin, - qmax); - } + auto [activations, activation_qvals, activation_scales, activation_zeros] = + test_utils::generate_per_token_quantized_tensor(m, k); // Generate weights assert(k % weight_group_size == 0); @@ -219,6 +245,8 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { auto weight_scales = std::vector(n_weight_groups, 0.0); auto weight_zeros = std::vector(n_weight_groups, 0); + int qmin, qmax, zero; + float vmin, vmax, scale; // Quantize weights with weight_nbit // TODO: replace with generic function that does not use aarch64 // quantize method after we combine with torchao @@ -322,6 +350,151 @@ struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { } }; +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool 
rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + // !Rhs transposed was considered if we were doing quantized(softmax(q@k)) @ + // quantized(v) Since v would have been [B, H, S, D]. And [S, D] would be + // rhs matrix which is not transposed when considered matmul terminology + // because for matmul we would have A[S_q, S] x B[S, D]. + // It would have been transposed if A[S_q, S] x B[D, S]. + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + test_utils::generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + test_utils::generate_per_token_quantized_tensor( + n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. 
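For illustration only (not part of this patch): with stride == 1, the two rhs layouts the comment above describes index a logical element (n_idx, k_idx) as follows; the reference loop below is the stride-generalized version of the same computation.

// Illustrative sketch, assuming stride == 1.
// rhs_is_transposed == true  -> rhs is stored n x k, so walk along a row.
// rhs_is_transposed == false -> rhs is stored k x n, so walk down a column.
inline int rhs_flat_index(bool rhs_is_transposed, int n, int k, int n_idx, int k_idx) {
  return rhs_is_transposed ? (n_idx * k + k_idx) : (k_idx * n + n_idx);
}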
+ + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + template struct lowbit_embedding_test_case { int num_embeddings; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h new file mode 100644 index 0000000000..52fb0851bc --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h @@ -0,0 +1,403 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include +#include +#include +#include + +namespace torchao { +struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case { + int b; + int s_q; + int s_k; + int h; + int d; + bool tranposed; + + size_t b_q_stride; + size_t h_q_stride; + size_t s_q_stride; + + size_t b_k_stride; + size_t h_k_stride; + size_t s_k_stride; + + size_t b_q_qparams_stride; + size_t h_q_qparams_stride; + size_t s_q_qparams_stride; + + size_t b_k_qparams_stride; + size_t h_k_qparams_stride; + size_t s_k_qparams_stride; + + std::vector expected_output; + + std::vector q; + std::vector q_qvals; + std::vector q_scales; + std::vector q_zeros; + + std::vector k; + std::vector k_qvals; + std::vector k_scales; + std::vector k_zeros; + + channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + int b_, + int s_q_, + int s_k_, + int h_, + int d_, + int transposed_, + size_t b_q_stride_, + size_t h_q_stride_, + size_t s_q_stride_, + size_t b_k_stride_, + size_t h_k_stride_, + size_t s_k_stride_, + size_t b_q_qparams_stride_, + size_t h_q_qparams_stride_, + size_t s_q_qparams_stride_, + size_t b_k_qparams_stride_, + size_t h_k_qparams_stride_, + size_t s_k_qparams_stride_, + std::vector expected_output_, + std::vector q_, + std::vector q_qvals_, + std::vector q_scales_, + std::vector q_zeros_, + std::vector k_, + std::vector k_qvals_, + std::vector k_scales_, + std::vector k_zeros_) + : b(b_), + s_q(s_q_), + s_k(s_k_), + h(h_), + d(d_), + tranposed(transposed_), + b_q_stride(b_q_stride_), + h_q_stride(h_q_stride_), + s_q_stride(s_q_stride_), + b_k_stride(b_k_stride_), + h_k_stride(h_k_stride_), + s_k_stride(s_k_stride_), + b_q_qparams_stride(b_q_qparams_stride_), + h_q_qparams_stride(h_q_qparams_stride_), + s_q_qparams_stride(s_q_qparams_stride_), + b_k_qparams_stride(b_k_qparams_stride_), + 
h_k_qparams_stride(h_k_qparams_stride_), + s_k_qparams_stride(s_k_qparams_stride_), + expected_output(expected_output_), + q(q_), + q_qvals(q_qvals_), + q_scales(q_scales_), + q_zeros(q_zeros_), + k(k_), + k_qvals(k_qvals_), + k_scales(k_scales_), + k_zeros(k_zeros_) { + assert(expected_output.size() == b * s_q * h * s_k); + assert(q.size() == b * s_q * h * d); + assert(q_qvals.size() == b * s_q * h * d); + assert(q_scales.size() == b * s_q * h); + assert(q_zeros.size() == b * s_q * h); + assert(k.size() == b * s_k * h * d); + assert(k_qvals.size() == b * s_k * h * d); + assert(k_scales.size() == b * s_k * h); + assert(k_zeros.size() == b * s_k * h); + } + + static channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case + generate(int b, int s_q, int s_k, int h, int d, bool transposed = true) { + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_q * h, d); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * s_k * h, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_q_stride = h * s_q * d; + size_t h_q_stride = s_q * d; + size_t s_q_stride = d; + + size_t b_k_stride = h * s_k * d; + size_t h_k_stride = s_k * d; + size_t s_k_stride = d; + + size_t b_q_qparams_stride = h * s_q; + size_t h_q_qparams_stride = s_q; + size_t s_q_qparams_stride = 1; + + size_t b_k_qparams_stride = h * s_k; + size_t h_k_qparams_stride = s_k; + size_t s_k_qparams_stride = 1; + + if (!transposed) { + h_q_stride = d; + s_q_stride = h * d; + h_k_stride = d; + s_k_stride = h * d; + + s_q_qparams_stride = h; + h_q_qparams_stride = 1; + + s_k_qparams_stride = h; + h_k_qparams_stride = 1; + } + + // Compute expected output + std::vector expected_output(b * h * s_q * s_k); + size_t b_out_stride = h * s_q * s_k; + size_t h_out_stride = s_q * s_k; + size_t s_q_out_stride = s_k; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_q_idx = 0; s_q_idx < s_q; s_q_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int s_k_idx = 0; s_k_idx < s_k; s_k_idx++) { + float res = 0.0; + for (int d_idx = 0; d_idx < d; d_idx++) { + int lhs_idx = b_idx * b_q_stride + s_q_idx * s_q_stride + + h_idx * h_q_stride + d_idx; + int rhs_idx = b_idx * b_k_stride + s_k_idx * s_k_stride + + h_idx * h_k_stride + d_idx; + int lhs_scales_zp_idx = b_idx * b_q_qparams_stride + + h_idx * h_q_qparams_stride + s_q_idx * s_q_qparams_stride; + int rhs_scales_zp_idx = b_idx * b_k_qparams_stride * h + + h_idx * h_k_qparams_stride + s_k_idx * s_k_qparams_stride; + float lhs_dequant = lhs_scales[lhs_scales_zp_idx] * + (lhs_qvals[lhs_idx] - lhs_zeros[lhs_scales_zp_idx]); + + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs_dequant * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_q_idx * s_q_out_stride + + h_idx * h_out_stride + s_k_idx] = res; + } + } + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case( + b, + s_q, + s_k, + h, + d, + transposed, + b_q_stride, + h_q_stride, + s_q_stride, + b_k_stride, + h_k_stride, + s_k_stride, + b_q_qparams_stride, + h_q_qparams_stride, + s_q_qparams_stride, + b_k_qparams_stride, + h_k_qparams_stride, + s_k_qparams_stride, + expected_output, + lhs, + 
lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; + +struct fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case { + int b; + int s_attn; + int s_v; + int h; + int d; + size_t b_attn_stride; + size_t h_attn_stride; + size_t s_attn_stride; + size_t b_v_stride; + size_t h_v_stride; + size_t s_v_stride; + size_t b_v_qparams_stride; + size_t h_v_qparams_stride; + size_t s_v_qparams_stride; + + std::vector expected_output; + + std::vector attn_scores; + + std::vector v; + std::vector v_qvals; + std::vector v_scales; + std::vector v_zeros; + + fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + int b_, + int s_attn_, + int s_v_, + int h_, + int d_, + size_t b_attn_stride_, + size_t h_attn_stride_, + size_t s_attn_stride_, + size_t b_v_stride_, + size_t h_v_stride_, + size_t s_v_stride_, + size_t b_v_qparams_stride_, + size_t h_v_qparams_stride_, + size_t s_v_qparams_stride_, + std::vector expected_output_, + std::vector attn_scores_, + std::vector v_, + std::vector v_qvals_, + std::vector v_scales_, + std::vector v_zeros_) + : b(b_), + s_attn(s_attn_), + s_v(s_v_), + h(h_), + d(d_), + b_attn_stride(b_attn_stride_), + h_attn_stride(h_attn_stride_), + s_attn_stride(s_attn_stride_), + b_v_stride(b_v_stride_), + h_v_stride(h_v_stride_), + s_v_stride(s_v_stride_), + b_v_qparams_stride(b_v_qparams_stride_), + h_v_qparams_stride(h_v_qparams_stride_), + s_v_qparams_stride(s_v_qparams_stride_), + expected_output(expected_output_), + attn_scores(attn_scores_), + v(v_), + v_qvals(v_qvals_), + v_scales(v_scales_), + v_zeros(v_zeros_) { + assert(expected_output.size() == b * s_attn * h * d); + assert(attn_scores.size() == b * h * s_attn * s_v); + assert(v.size() == b * h * s_v * d); + assert(v_qvals.size() == b * h * s_v * d); + assert(v_scales.size() == b * h * s_v); + assert(v_zeros.size() == b * h * s_v); + } + + static fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case + generate(int b, int s_attn, int s_v, int h, int d, bool transposed_v = true) { + // Generate activations + auto lhs = get_random_vector(b * h * s_attn * s_v, -1.0, 1.0); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + torchao::test_utils::generate_per_token_quantized_tensor( + b * h * s_v, d); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. + + size_t b_attn_stride = h * s_attn * s_v; + size_t h_attn_stride = s_attn * s_v; + size_t s_attn_stride = s_v; + + size_t b_v_stride = h * s_v * d; + size_t h_v_stride = s_v * d; + size_t s_v_stride = d; + + size_t b_v_qparams_stride = h * s_v; + size_t h_v_qparams_stride = s_v; + size_t s_v_qparams_stride = 1; + + if (!transposed_v) { + h_v_stride = d; + s_v_stride = h * d; + + s_v_qparams_stride = h; + h_v_qparams_stride = 1; + } + + // Compute expected output + // Note that while the inputs can be in shape b x h x s_attn x s_v, + // and b x h x s_v x d the output is not in b x h x s_attn x s_v + // but rather b x s_attn x h x d. This is because the output of + // SDPA will normally be in b x h x s_attn x d, but we want to + // avoid any tranposes. Thus just aim to output in b x s_attn x h x d + // This is just for testing purposes. Kernel can actually write output + // in [B, H, S, D] if needed. 
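For illustration only (not part of this patch): the two flattened offsets the comment above contrasts, for element (b, h, s, d) with H heads, S_attn attention rows, and head dimension D. The reference loop below uses the second layout, which is why the kernel output needs no transpose.

// Illustrative sketch of the two output layouts.
inline int offset_bhsd(int b, int h, int s, int d, int H, int S_attn, int D) {
  return ((b * H + h) * S_attn + s) * D + d;  // [B, H, S_attn, D]: usual SDPA output
}
inline int offset_bshd(int b, int h, int s, int d, int H, int S_attn, int D) {
  return ((b * S_attn + s) * H + h) * D + d;  // [B, S_attn, H, D]: layout produced below
}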
+ std::vector expected_output(b * s_attn * h * d); + size_t b_out_stride = s_attn * h * d; + size_t s_attn_out_stride = h * d; + size_t h_out_stride = d; + + for (int b_idx = 0; b_idx < b; b_idx++) { + for (int s_attn_idx = 0; s_attn_idx < s_attn; s_attn_idx++) { + for (int h_idx = 0; h_idx < h; h_idx++) { + for (int d_idx = 0; d_idx < d; d_idx++) { + float res = 0.0; + for (int s_v_idx = 0; s_v_idx < s_v; s_v_idx++) { + int lhs_idx = b_idx * b_attn_stride + s_attn_idx * s_attn_stride + + h_idx * h_attn_stride + s_v_idx; + int rhs_idx = b_idx * b_v_stride + h_idx * h_v_stride + d_idx + + s_v_idx * s_v_stride; + int rhs_scales_zp_idx = b_idx * b_v_qparams_stride + + h_idx * h_v_qparams_stride + s_v_idx * s_v_qparams_stride; + float rhs_dequant = rhs_scales[rhs_scales_zp_idx] * + (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output + [b_idx * b_out_stride + s_attn_idx * s_attn_out_stride + + h_idx * h_out_stride + d_idx] = res; + } + } + } + } + + // Return test case + return fp32_a_channelwise_8bit_b_attn_scores_at_v_test_case( + b, + s_attn, + s_v, + h, + d, + b_attn_stride, + h_attn_stride, + s_attn_stride, + b_v_stride, + h_v_stride, + s_v_stride, + b_v_qparams_stride, + h_v_qparams_stride, + s_v_qparams_stride, + expected_output, + lhs, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; +} // namespace torchao + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h new file mode 100644 index 0000000000..3b070eb2b3 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h @@ -0,0 +1,133 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b::internal { + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_tranposed> +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride); +}; + +template +struct KernelImpl { + static void run( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + const int8_t* lhs_qvals = static_cast(lhs); + const int8_t* rhs_qvals = static_cast(rhs); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + + float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] * + (static_cast(lhs_qvals[lhs_idx]) - + static_cast( + lhs_zero_points[m_idx * lhs_qparams_stride])); + + float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast( + rhs_zero_points[n_idx * rhs_qparams_stride])); + + res += lhs_dequant * rhs_dequant; + } + output[m_idx * n + n_idx] = res; + } + } + } +}; + +} // namespace + // channelwise_8bit_a_channelwise_8bit_b::internal +} // namespace torchao::kernels::cpu::fallback::quantized_matmul + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace channelwise_8bit_a_channelwise_8bit_b { +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +void kernel( + int m, + int n, + int k, + const void* lhs, + int lhs_stride_m, + const void* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* lhs_zero_points, + const int8_t* rhs_zero_points, + const float* lhs_scales, + const float* rhs_scales, + const int lhs_qparams_stride, + const int rhs_qparams_stride) { + channelwise_8bit_a_channelwise_8bit_b::internal:: + KernelImpl::run( + m, + n, + k, + lhs, + lhs_stride_m, + rhs, + rhs_stride_n, + output, + out_stride_m, + lhs_zero_points, + rhs_zero_points, + lhs_scales, + rhs_scales, + lhs_qparams_stride, + rhs_qparams_stride); +} +} // namespace channelwise_8bit_a_channelwise_8bit_b +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h new file mode 100644 index 0000000000..58e2853617 --- /dev/null +++ b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include + +// TODO: Remove all ::kernels. No need for extra namespace. +namespace torchao::kernels::cpu::fallback::quantized_matmul { +namespace fp32_a_input_channelwise_8bit_b_fp32 { +template +void kernel( + int m, + int n, + int k, + const float* lhs, + int lhs_stride_m, + const int8_t* rhs, + int rhs_stride_n, + float* output, + int out_stride_m, + const int8_t* rhs_zero_points, + const float* rhs_scales, + const float beta, + const int rhs_qparams_stride) { + assert(a_transposed == false); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * lhs_stride_m + k_idx; + int rhs_idx = k_idx * rhs_stride_n + n_idx; + if (b_transposed) { + rhs_idx = n_idx * rhs_stride_n + k_idx; + } + float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] * + (static_cast(rhs[rhs_idx]) - + static_cast(rhs_zero_points[k_idx * rhs_qparams_stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + output[m_idx * n + n_idx] = output[m_idx * n + n_idx] * beta + res; + } + } +} +} // namespace fp32_a_input_channelwise_8bit_b_fp32 +} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h new file mode 100644 index 0000000000..718f7eaad9 --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h @@ -0,0 +1,158 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#include +#include + +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#include +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +namespace torchao::kernels::cpu::quantized_matmul { + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. 
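An illustrative usage sketch follows (hypothetical caller-owned buffers lhs_qvals, rhs_qvals, output and their per-row scales and zero points; the tests in test_qmatmul_interface.cpp added below exercise this same pattern):

  int a_stride_m = 0;
  int b_stride_n = 0;
  auto matmul = torchao::kernels::cpu::quantized_matmul::
      get_int8_a_int8_b_channelwise_qmatmul(
          m, n, k, false, true, a_stride_m, b_stride_n);  // a not transposed, b transposed (n x k)
  matmul(
      m, n, k,
      lhs_qvals, a_stride_m,
      rhs_qvals, b_stride_n,
      output, n,                       // out_stride_m == n for a dense m x n output
      lhs_zero_points, rhs_zero_points,
      lhs_scales, rhs_scales,
      1, 1);                           // lhs/rhs qparams stride of 1 (one qparam per row)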
+*/ +using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)( + int, + int, + int, + const void*, + int, + const void*, + int, + float*, + int, + const int8_t*, + const int8_t*, + const float*, + const float*, + const int, + const int); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +int8_a_int8_b_channelwise_fp32_c_qmatmul_type +get_int8_a_int8_b_channelwise_qmatmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && b_transposed && n >= 8) { + a_stride_m = k; + b_stride_n = k; + return aarch64::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot:: + kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } else { + return torchao::kernels::cpu::fallback::quantized_matmul:: + channelwise_8bit_a_channelwise_8bit_b::kernel; + } +} + +/* +a_stride_m: stride of a in memory to indiciate how far apart each row is. +b_stride_n: stride of b in memory to indiciate how far apart each row is. +If b is transposed (n x k), then this is how many bytes to skip to get to the +next row. If b is not transposed (k x n), then this is how many bytes to skip to +get to the next row. + +It also returns the stride of a and b, that should be used in the kernel. + +Will need to think of a better way to find the right +ukernel. Perhaps via ukernelconfig + registry?. +*/ +using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)( + int, + int, + int, + const float*, + int, + const int8_t*, + int, + float*, + int, + const int8_t*, + const float*, + const float, + const int); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n); + +fp32_a_input_channelwise_8bit_b_f32_c_matmul_type +get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + int m, + int n, + int k, + bool a_transposed, + bool b_transposed, + int& a_stride_m, + int& b_stride_n) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (!a_transposed && !b_transposed && n >= 16) { + a_stride_m = k; + b_stride_n = n; + return aarch64::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + assert(!a_transposed); + if (b_transposed) { + a_stride_m = k; + b_stride_n = k; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } else { + a_stride_m = k; + b_stride_n = n; + return torchao::kernels::cpu::fallback::quantized_matmul:: + fp32_a_input_channelwise_8bit_b_fp32::kernel; + } +} +} // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp new file mode 100644 index 0000000000..4024f3f1de --- /dev/null +++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp @@ -0,0 +1,630 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include +#include + +float kTol = 0.0001; + +// This is unfortunately had to be copied over because code in test_utils.h +// depends on quantization kernels which are only buildable for ARM. +// I would like the testing code in this folder to be independent of the arch. +namespace { +void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) { + if (is_symmetric) { + qmin = -(1 << (nbit - 1)) + 1; + qmax = -qmin; + } else { + qmin = -(1 << (nbit - 1)); + qmax = (1 << (nbit - 1)) - 1; + } +} + +void get_scale_and_zero( + float& scale, + int& zero, + float vmin, + float vmax, + int qmin, + int qmax) { + assert(qmin < qmax); + assert(vmin < vmax); + scale = (vmax - vmin) / (qmax - qmin); + zero = qmin - std::round(vmin / scale); +} + +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +void quantize( + // Output + int8_t* qvals, + // Inputs + const float* vals, + int size, + float scale, + int8_t zero, + int8_t qmin, + int8_t qmax) { + float invScale = 1.0 / (scale + 1e-16); + int i = 0; + auto curr_rounding_mode = fegetround(); + fesetround(FE_TONEAREST); + for (; i < size; ++i) { + // Quantize remaining elements using scalar code + float val = vals[i]; + float qval_f32 = zero + val * invScale; + int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); + + // Clip to qmin and qmax + qval_s32 = std::max( + static_cast(qmin), + std::min(qval_s32, static_cast(qmax))); + + // Store the quantized value + qvals[i] = static_cast(qval_s32); + } + fesetround(int(curr_rounding_mode)); +} + +auto generate_per_token_quantized_tensor( + int m, + int n, + bool transposed = false) { + auto activations = get_random_vector(m * n, -1.0, 1.0); + auto activation_qvals = std::vector(m * n, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + auto minmax = std::minmax_element( + activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n); + vmin = *minmax.first; + vmax = *minmax.second; + get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + quantize( + /*qvals=*/activation_qvals.data() + m_idx * n, + /*vals=*/activations.data() + m_idx * n, + /*size=*/n, + scale, + zero, + qmin, + qmax); + } + + if (transposed) { + auto activations_t = std::vector(m * n, 0); + auto activation_qvals_t = std::vector(m * n, 0); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + int activation_idx = m_idx * n + n_idx; + int tranposed_idx = n_idx * m + m_idx; + activations_t[tranposed_idx] = activations[activation_idx]; + activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; + } + } + activations = activations_t; + activation_qvals = 
activation_qvals_t; + } + + return std::make_tuple( + activations, activation_qvals, activation_scales, activation_zeros); +} + +struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { + int m; + int k; + int n; + int stride; + + bool lhs_has_zeros; + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector expected_output; + + std::vector lhs; + std::vector lhs_qvals; + std::vector lhs_scales; + std::vector lhs_zeros; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + int m_, + int k_, + int n_, + int stride_, + bool lhs_has_zeros_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + std::vector expected_output_, + std::vector lhs_, + std::vector lhs_qvals_, + std::vector lhs_scales_, + std::vector lhs_zeros_, + std::vector rhs_, + std::vector rhs_qvals_, + std::vector rhs_scales_, + std::vector rhs_zeros_) + : m(m_), + k(k_), + n(n_), + stride(stride_), + lhs_has_zeros(lhs_has_zeros_), + rhs_has_zeros(rhs_has_zeros_), + lhs_is_transposed(lhs_is_transposed_), + rhs_is_transposed(rhs_is_transposed_), + expected_output(expected_output_), + lhs(lhs_), + lhs_qvals(lhs_qvals_), + lhs_scales(lhs_scales_), + lhs_zeros(lhs_zeros_), + rhs(rhs_), + rhs_qvals(rhs_qvals_), + rhs_scales(rhs_scales_), + rhs_zeros(rhs_zeros_) { + assert(expected_output.size() == m * n); + assert(lhs.size() == m * stride * k); + assert(lhs_qvals.size() == m * stride * k); + assert(lhs_scales.size() == m * stride); + assert(lhs_zeros.size() == m * stride); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == n * stride); + assert(rhs_zeros.size() == n * stride); + } + + static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( + int m, + int k, + int n, + bool lhs_has_zeros, + bool rhs_has_zeros, + bool lhs_is_transposed, + // rhs_is_transposed means generated b matrix is mxk instead of kxm + bool rhs_is_transposed, + int stride = 1) { + assert(!lhs_is_transposed); + assert(lhs_has_zeros); + assert(rhs_has_zeros); + assert(rhs_is_transposed || stride == 1); + // Generate activations + auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = + generate_per_token_quantized_tensor(m * stride, k); + + auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = + generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); + // Above function produces nxk matrix and to produce kxn you need transposed + // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true + // the shape should be nxk instead of kxn. 
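For illustration only (not part of this patch): the stride argument spaces out rows and their quantization parameters, so the reference computation below reads only every stride-th row of the generated buffers.

// Illustrative sketch of the strided indexing used by the reference loop.
inline int lhs_row_offset(int m_idx, int k, int stride) {
  return m_idx * stride * k;  // start of logical row m_idx in lhs_qvals
}
inline int lhs_qparam_index(int m_idx, int stride) {
  return m_idx * stride;      // scale / zero point index for that row
}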
+ + // Compute expected output + std::vector expected_output(m * n); + + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * stride * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx * stride; + if (rhs_is_transposed) { + rhs_idx = n_idx * stride * k + k_idx; + } + float lhs_dequant = lhs_scales[m_idx * stride] * + (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); + + float rhs_dequant = rhs_scales[n_idx * stride] * + (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); + + res += lhs_dequant * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( + m, + k, + n, + stride, + lhs_has_zeros, + rhs_has_zeros, + lhs_is_transposed, + rhs_is_transposed, + expected_output, + lhs, + lhs_qvals, + lhs_scales, + lhs_zeros, + rhs, + rhs_qvals, + rhs_scales, + rhs_zeros); + } +}; +} // namespace + +template < + bool a_has_zeros, + bool b_has_zeros, + bool a_transposed, + bool b_transposed> +struct test_channelwise_8bit_channelwise_8bit_b { + static void Run(int m, int k, int n); +}; + +template +struct test_channelwise_8bit_channelwise_8bit_b< + a_has_zeros, + b_has_zeros, + false, + true> { + static void Run(int m, int k, int n, int stride = 1) { + auto test_case = + channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate( + m, k, n, a_has_zeros, a_has_zeros, false, true, stride); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_int8_a_int8_b_channelwise_qmatmul( + m, n, k, false, true, a_stride_m, b_stride_n); + a_stride_m = a_stride_m * stride; + b_stride_n = b_stride_n * stride; + + std::vector output(m * n); + kernel( + m, + n, + k, + test_case.lhs_qvals.data(), + a_stride_m /*lsh_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rsh_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.lhs_zeros.data(), + test_case.rhs_zeros.data(), + test_case.lhs_scales.data(), + test_case.rhs_scales.data(), + stride, /*lhs qparams stride*/ + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } + } +}; + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposedBWithZeroPoints) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/1, /*k=*/128, /*n=*/16); +} + +TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/24); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback) { + 
test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposeBWithZeroPointsLargeMStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/128, /*n=*/16, 5); +} + +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizes2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/19, 16); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallbackStrided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/37, /*n=*/5, 7); +} + +// Test shapes for which we have to use fallback kernel +TEST( + test_channelwise_8bit_channelwise_8bit_b, + TranposedBWithZeroPointsOddSizesFallback2Strided) { + test_channelwise_8bit_channelwise_8bit_b< + true /*a_has_zeros*/, + true /*b_has_zeros*/, + false /*a_transposed*/, + true /*b_transposed*/>:: + Run( + /*m=*/4, /*k=*/2, /*n=*/1, 32); +} + +class FP32A_QuantizedB_FP32C_Interface_Test + : public ::testing::TestWithParam { + public: + int m; + int k; + int n; + int stride; + + bool rhs_has_zeros; + bool lhs_is_transposed; + bool rhs_is_transposed; + + std::vector init_output; + std::vector expected_output; + + std::vector lhs; + + std::vector rhs; + std::vector rhs_qvals; + std::vector rhs_scales; + std::vector rhs_zeros; + + void generate( + int m_, + int k_, + int n_, + bool rhs_has_zeros_, + bool lhs_is_transposed_, + bool rhs_is_transposed_, + int stride_ = 1) { + assert(!lhs_is_transposed_); + assert(rhs_has_zeros_); + m = m_; + k = k_; + n = n_; + stride = stride_; + rhs_has_zeros = rhs_has_zeros_; + lhs_is_transposed = lhs_is_transposed_; + rhs_is_transposed = rhs_is_transposed_; + + assert(!rhs_is_transposed || stride == 1); + + // Generate activations + lhs = get_random_vector(m * k, -1.0, 1.0); + + // The strange thing this is doing is that instead of quantizing + // each output channel separately, we are quantizing each input channel + // Reason why we do !rhs_is_transposed is because + // we actually want k x n matrix not n x k matrix + // because each input channel is quantized separately + std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = + generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed); + + // Compute expected output + init_output = get_random_vector(m * n, -1.0, 1.0); + + assert(init_output.size() == m * n); + assert(lhs.size() == m * k); + assert(rhs.size() == n * stride * k); + assert(rhs_qvals.size() == n * stride * k); + assert(rhs_scales.size() == k * stride); + assert(rhs_zeros.size() == k * stride); + } + + void execute(float beta) { + // Compute expected output + expected_output = init_output; + + for (int 
m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int lhs_idx = m_idx * k + k_idx; + int rhs_idx = k_idx * stride * n + n_idx; + if (rhs_is_transposed) { + rhs_idx = n_idx * k * stride + k_idx * stride; + } + float rhs_dequant = rhs_scales[k_idx * stride] * + (static_cast(rhs_qvals[rhs_idx]) - + static_cast(rhs_zeros[k_idx * stride])); + + res += lhs[lhs_idx] * rhs_dequant; + } + expected_output[m_idx * n + n_idx] = + expected_output[m_idx * n + n_idx] * beta + res; + } + } + } + + float beta() const { + return GetParam(); + } +}; + +static void test_fp32_a_input_channelwise_8bit_b( + int m, + int k, + int n, + float beta, + FP32A_QuantizedB_FP32C_Interface_Test& test_case, + int stride = 1) { + test_case.execute(beta); + + int a_stride_m, b_stride_n; + auto kernel = torchao::kernels::cpu::quantized_matmul:: + get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( + m, n, k, false, false, a_stride_m, b_stride_n); + b_stride_n = b_stride_n * stride; + + std::vector output(test_case.init_output); + kernel( + m, + n, + k, + test_case.lhs.data(), + a_stride_m /*lhs_stride_m*/, + test_case.rhs_qvals.data(), + b_stride_n /*rhs_stride_n*/, + output.data(), + n /*out_stride_n*/, + test_case.rhs_zeros.data(), + test_case.rhs_scales.data(), + beta, + stride /*rhs qparams stride*/); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTranposedWithZeroPoints) { + generate(3, 128, 16, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes) { + generate(4, 37, 19, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); +} + +// Test shapes for which we have to use fallback kernel +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesFallback) { + generate(4, 37, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2Fallback) { + generate(4, 1, 3, true, false, false); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizesStrided) { + generate(4, 37, 19, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32); +} + +TEST_P( + FP32A_QuantizedB_FP32C_Interface_Test, + BTranposedWithZeroPointsOddSizes2FallbackStrided) { + generate(4, 5, 3, true, false, false, 32); + test_fp32_a_input_channelwise_8bit_b( + /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32); +} + +INSTANTIATE_TEST_SUITE_P( + F32AInt8BFP32CTest, + FP32A_QuantizedB_FP32C_Interface_Test, + ::testing::Values(0.0, 1.0, 3.1)); diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h index 22b87cfb9e..8113a0566b 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h @@ -253,9 +253,11 @@ Tensor shared_embedding_out_cpu( torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto format = 
torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: from_packed_weights_header(header); - torchao::ops::linear_8bit_act_xbit_weight::check_format( + + torchao::ops::linear_8bit_act_xbit_weight::check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); constexpr int nr = 8; constexpr int kr = 16; constexpr int sr = 2; @@ -316,12 +318,7 @@ Tensor shared_embedding_cpu( const Tensor& indices) { Tensor output_tensor = torch::empty({}, torch::kFloat32); shared_embedding_out_cpu( - packed_weights, - group_size, - n, - k, - indices, - output_tensor); + packed_weights, group_size, n, k, indices, output_tensor); return output_tensor; } #endif // USE_ATEN diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h new file mode 100644 index 0000000000..1e4a9ef670 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h @@ -0,0 +1,238 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +constexpr int kMaxLinearConfigs = 4; +struct UKernelConfig { + // Size of packed_activations buffer + using packed_activations_size_fn_type = size_t (*)( + int m, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Offset in packed_activations buffer for a given m_idx + // m_idx is index in unpacked activations matrix; it will be a multiple of + // m_step + using packed_activations_offset_fn_type = size_t (*)( + int m_idx, + int k, + int group_size, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Pack activations into packed_activations buffer + using pack_activations_fn_type = void (*)( + void* packed_activations, + int m, + int k, + int group_size, + const float* activations, + bool has_weight_zeros, + int mr, + int kr, + int sr); + + // Size of packed_weights buffer + using packed_weights_size_fn_type = size_t (*)( + int n, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Offset in packed_weights buffer for a given n_idx + // n_inx is index in unpacked weights matrix; it will be a multiple of n_step + using packed_weights_offset_fn_type = size_t (*)( + int n_idx, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr); + + // Pack weights into packed_weights buffer + using pack_weights_fn_type = void (*)( + void* packed_weights, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + const float* bias, + int nr, + int kr, + int sr); + + // Run matmul kernel + using kernel_fn_type = void (*)( + float* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_weight_zeros, + bool has_bias, + bool has_clamp); + + struct linear_config_type { + int m_step{0}; // m_idx will be a multiple of this + int mr{0}; + packed_activations_size_fn_type packed_activations_size{nullptr}; + 
packed_activations_offset_fn_type packed_activations_offset{nullptr}; + pack_activations_fn_type pack_activations{nullptr}; + kernel_fn_type kernel{nullptr}; + }; + + // preferred_alignment for packed_activations and packed_weights + // Integration surfaces are not required to respect this alignment, and the + // kernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; + int n_step{0}; // n_idx will be a multiple of this + int nr{0}; + int kr{0}; + int sr{0}; + int weight_nbit{0}; + bool has_weight_zeros{false}; + bool has_bias{false}; + packed_weights_size_fn_type packed_weights_size{nullptr}; + packed_weights_offset_fn_type packed_weights_offset{nullptr}; + pack_weights_fn_type pack_weights{nullptr}; + + // linear_configs must be sorted in ascending m_step + std::array linear_configs; + + static UKernelConfig make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs); + + inline void validate() const { + TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); + TORCHAO_CHECK(nr >= 1, "nr must be >= 1"); + TORCHAO_CHECK(kr >= 1, "kr must be >= 1"); + TORCHAO_CHECK(sr >= 1, "sr must be >= 1"); + TORCHAO_CHECK(weight_nbit >= 1, "weight_nbit must be >= 1"); + TORCHAO_CHECK( + packed_weights_size != nullptr, "packed_weights_size must be set"); + TORCHAO_CHECK( + packed_weights_offset != nullptr, "packed_weights_offset must be set"); + TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set"); + + bool linear_configs_set = true; // first linear config must be set + for (int i = 0; i < linear_configs.size(); i++) { + if (linear_configs_set) { + TORCHAO_CHECK( + linear_configs[i].m_step >= 1, + "linear_configs[i].m_step must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].mr >= 1, "linear_configs[i].mr must be >= 1"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_size != nullptr, + "linear_configs[i].packed_activations_size must be set"); + TORCHAO_CHECK( + linear_configs[i].packed_activations_offset != nullptr, + "linear_configs[i].packed_activations_offset must be set"); + TORCHAO_CHECK( + linear_configs[i].pack_activations != nullptr, + "linear_configs[i].pack_activations must be set"); + TORCHAO_CHECK( + linear_configs[i].kernel != nullptr, + "linear_configs[i].kernel must be set"); + if (i >= 1) { + TORCHAO_CHECK( + linear_configs[i - 1].m_step < linear_configs[i].m_step, + "set linear_configs must be increasing in m_step"); + } + if (i + 1 < linear_configs.size()) { + linear_configs_set = (linear_configs[i + 1].m_step >= 1); + } + } + } + } + + inline int select_linear_config_idx(int m) const { + assert(m >= 1); + assert(linear_configs[0].m_step >= 1); + + int i = 0; + while (i + 1 < linear_configs.size() && linear_configs[i + 1].m_step >= 1 && + linear_configs[i + 1].m_step <= m) { + assert(linear_configs[i].m_step < linear_configs[i + 1].m_step); + i++; + } + + assert(i < linear_configs.size()); + assert(linear_configs[i].m_step >= 1); + assert(i == 0 || linear_configs[i].m_step <= m); + return i; + } +}; + +inline UKernelConfig UKernelConfig::make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + 
packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array linear_configs) { + return UKernelConfig{ + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + packed_weights_size, + packed_weights_offset, + pack_weights, + std::move(linear_configs)}; +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 17d7ec13b1..ffdd62f7a7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,20 +6,21 @@ #pragma once #include -#include +#include #include - -#if defined(TORCHAO_BUILD_CPU_AARCH64) -#include -#endif // TORCHAO_BUILD_CPU_AARCH64 - #include #include #include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +#include +#endif // TORCHAO_ENABLE_ARM_NEON_DOT + #if defined(TORCHAO_ENABLE_KLEIDI) #include #endif // TORCHAO_ENABLE_KLEIDI +#endif // TORCHAO_BUILD_CPU_AARCH64 namespace torchao::ops::linear_8bit_act_xbit_weight { @@ -50,6 +51,7 @@ struct UKernelConfigRegistrationTable { throw std::runtime_error( "UKernelConfig is already registered for this format"); } + config.validate(); registration_table_[key] = config; } std::optional get_ukernel_config( @@ -89,73 +91,96 @@ void register_ukernel_config_universal( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format( + + check_format( format, - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit); + + namespace kernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + + constexpr bool has_lut = false; + int preferred_alignment = 16; if (format.nr == 8 && format.kr == 16 && format.sr == 2) { -#if defined(TORCHAO_BUILD_CPU_AARCH64) + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + constexpr int m_step = 1; + +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { - log_registration(format, "universal"); - namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}); - return; + log_registration(format, "universal: kernel_1x8x16_f32_neondot"); + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + format.has_weight_zeros, + format.has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); + + if (format.has_weight_zeros) { + constexpr bool has_weight_zeros = true; + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + 
&kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; + } else { + constexpr bool has_weight_zeros = false; + uk.linear_configs[0] = UKernelConfig::linear_config_type( + {m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel::kernel_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_lut>}); + + table.register_ukernel_config(format, uarch, std::move(uk)); + return; + } } -#endif // TORCHAO_BUILD_CPU_AARCH64 +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } #if defined(TORCHAO_ENABLE_KLEIDI) -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig::linear_config_type get_linear_config_kleidi() { +template +UKernelConfig::linear_config_type +get_linear_config_kleidi(int n_step, int nr, int kr, int sr) { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - assert(m_step == kernel_struct::get_ukernel().get_m_step()); - assert(mr == kernel_struct::get_ukernel().get_mr()); assert(n_step == kernel_struct::get_ukernel().get_n_step()); assert(nr == kernel_struct::get_ukernel().get_nr()); assert(kr == kernel_struct::get_ukernel().get_kr()); assert(sr == kernel_struct::get_ukernel().get_sr()); - return UKernelConfig::linear_config_type{ - /*mr*/ m_step, - /*activation_data_size_fn*/ &op::activation_data_size, - /*prepare_activation_data_fn*/ &op::prepare_activation_data, - /*kernel*/ &kernel_struct::kernel}; -} - -template -UKernelConfig::weight_packing_config_type get_weight_packing_config_kleidi() { - namespace op = torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p; - return UKernelConfig::weight_packing_config_type( - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}); + return UKernelConfig::linear_config_type( + {static_cast(kernel_struct::get_ukernel().get_m_step()), + static_cast(kernel_struct::get_ukernel().get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}); } template @@ -166,95 +191,79 @@ void register_ukernel_config_kleidi( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - check_format(format, torchao::ops::PackedWeightsType::kleidi_ai); + check_format(format, torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit); namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = UKernelConfig::make( + /*preferred_alignment*/ op::get_preferred_alignement(), + /*n_step*/ format.nr, + format.nr, + format.kr, + format.sr, + format.weight_nbit, + format.has_weight_zeros, + format.has_bias, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + {} /*linear_configs*/); + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; + uk.n_step = 8; + #if defined(TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - constexpr int n_step = 8; log_registration( format, - "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ 
n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - /*m_step*/ 4, - /*mr*/ 4, - n_step, - nr, - kr, - sr>()}}}); + "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm"); + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + + /*m_step=4*/ + uk.linear_configs[1] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>( + uk.n_step, uk.nr, uk.kr, uk.sr); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } #endif // TORCHAO_ENABLE_ARM_I8MM +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 8; log_registration( format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + /*m_step=1*/ + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } - if (format.nr == 4 && format.kr == 16 && format.sr == 2) { - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; + if (format.nr == 8 && format.kr == 8 && format.sr == 2) { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) if (cpuinfo_has_arm_neon_dot()) { - constexpr int n_step = 4; log_registration( format, - "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod"); - table.register_ukernel_config( - format, - uarch, - UKernelConfig{ - /*preferred_alignment*/ op::get_preferred_alignement(), - /*nr*/ n_step, - /*weight_packing_config*/ - get_weight_packing_config_kleidi(), - /*linear_configs*/ - {{get_linear_config_kleidi< - op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /*m_step*/ 1, - /*mr*/ 1, - n_step, - nr, - kr, - sr>()}}}); + "kleidiai: matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod, matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod"); + // m_step 1 + uk.linear_configs[0] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + // m_step 4 + uk.linear_configs[1] = get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>( + uk.n_step, uk.nr, uk.kr, uk.sr); + table.register_ukernel_config(format, uarch, std::move(uk)); return; } +#endif // TORCHAO_ENABLE_ARM_NEON_DOT } } #endif // TORCHAO_ENABLE_KLEIDI @@ -328,22 +337,32 @@ PackedWeightsFormat select_packed_weights_format( #if defined(TORCHAO_ENABLE_KLEIDI) if (!target || *target == "kleidiai") { if (weight_nbit == 4 && (!has_weight_zeros)) { - // KleidiAI will pack bias with weights always, - // even if bias is not provided 0s will be packed +#if defined(TORCHAO_ENABLE_ARM_I8MM) return PackedWeightsFormat( torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, has_weight_zeros, - 
/*has_bias*/ true, + has_bias, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#elif defined(TORCHAO_ENABLE_ARM_NEON_DOT) + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, + weight_nbit, + has_weight_zeros, + has_bias, + /*nr*/ 8, + /*kr*/ 8, + /*sr*/ 2); +#endif } } #endif // defined(TORCHAO_ENABLE_KLEIDI) // Select universal format if (!target || *target == "universal") { +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) return PackedWeightsFormat( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, weight_nbit, @@ -352,6 +371,7 @@ PackedWeightsFormat select_packed_weights_format( /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); +#endif // defined(TORCHAO_ENABLE_ARM_NEON_DOT) } throw std::runtime_error("No packed_weights_format was selected"); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 0421e6a25f..6929e6e4a4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -7,43 +7,19 @@ #include #include #include +#include #include #include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread) { - TORCHAO_CHECK(n >= 1, "n must be >= 1"); - TORCHAO_CHECK( - target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1"); - - PackWeightDataTilingParams tiling_params; - int nr = ukernel_config.nr; - int num_threads = torchao::get_num_threads(); - int numerator = n; - int denominator = num_threads * target_panels_per_thread; - - // Set nc = ceil(numerator / denominator) - int nc = (numerator + denominator - 1) / denominator; - assert(nc >= 1); - - // Replace nc with the next number nr divides - nc = ((nc + nr - 1) / nr) * nr; - tiling_params.nc_by_nr = nc / nr; - - return tiling_params; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -54,12 +30,14 @@ void pack_weight_data_operator( const float* bias) { TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + TORCHAO_CHECK( + uk.has_bias == (bias != nullptr), "bias/has_bias is inconsistent"); + TORCHAO_CHECK( + uk.has_weight_zeros == (weight_zeros != nullptr), + "weight_zeros/has_weight_zeros is inconsistent"); - bool has_weight_zeros = (weight_zeros != nullptr); - bool has_bias = (bias != nullptr); - - int nr = ukernel_config.nr; - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int n_step = uk.n_step; + int nc = std::min(n, n_step); int num_nc_panels = (n + nc - 1) / nc; torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { @@ -67,50 +45,53 @@ void pack_weight_data_operator( int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int weight_data_offset = (n_idx / nr) * - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + 
uk.sr); + int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k / group_size); - - const int8_t* weight_zeros_ptr = nullptr; - if (weight_zeros != nullptr) { - weight_zeros_ptr = weight_zeros + weight_scales_and_zeros_offset; - } - const float* bias_ptr = nullptr; - if (bias != nullptr) { - bias_ptr = bias + n_idx; - } - - ukernel_config.weight_packing_config.prepare_weight_data_fn( - (char*)weight_data + weight_data_offset, + uk.pack_weights( + (char*)packed_weights + packed_weights_offset, /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, weight_scales + weight_scales_and_zeros_offset, - weight_zeros_ptr, - bias_ptr); + (weight_zeros == nullptr) + ? nullptr + : (weight_zeros + weight_scales_and_zeros_offset), + (bias == nullptr) ? nullptr : (bias + n_idx), + uk.nr, + uk.kr, + uk.sr); }); } -// This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, +LinearTilingParams LinearTilingParams::from_target_tiles_per_thread( int m, + int m_step, int n, + int n_step, int target_tiles_per_thread) { TORCHAO_CHECK(m >= 1, "m must be >= 1"); + TORCHAO_CHECK(m_step >= 1, "m_step must be >= 1"); + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); TORCHAO_CHECK( target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1"); - - LinearTilingParams tiling_params; auto num_threads = torchao::get_num_threads(); TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); - tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr; + int mc = m_step; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -120,50 +101,25 @@ LinearTilingParams get_default_linear_tiling_params( int nc = (numerator + denominator - 1) / denominator; assert(nc >= 1); - // Replace nc with next number nr divides - int nr = ukernel_config.nr; - nc = ((nc + nr - 1) / nr) * nr; - assert(nc % nr == 0); - tiling_params.nc_by_nr = nc / nr; + // Replace nc with next number n_step divides + nc = ((nc + n_step - 1) / n_step) * n_step; - assert(tiling_params.mc_by_mr >= 1); - assert(tiling_params.nc_by_nr >= 1); - return tiling_params; -} - -namespace internal { + // Clamp mc, nc to be no larger than m, n + mc = std::min(m, mc); + nc = std::min(n, nc); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr, - k, - group_size, - has_weight_zeros); -} + assert((mc == m) || (mc % m_step == 0)); + assert((nc == n) || (nc % n_step == 0)); -inline size_t -get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size, - bool has_weight_zeros) { - return ukernel_config.linear_configs[0].activation_data_size_fn( - m, k, group_size, has_weight_zeros); + LinearTilingParams tiling_params; + tiling_params.mc = mc; + tiling_params.nc = nc; + return tiling_params; } -inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* 
activation_data_buffer, +void linear_operator( + const UKernelConfig& uk, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -171,237 +127,101 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, - // Ignored if has_clamp = false + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int nr = ukernel_config.nr; - int mc = - std::min(m, tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); + float clamp_max) { + TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); + TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); + + // Select linear config based on m + int linear_config_idx = uk.select_linear_config_idx(m); + auto& linear_config = uk.linear_configs[linear_config_idx]; + int n_step = uk.n_step; + int m_step = linear_config.m_step; + + // Choose tiling params + int mc, nc; + if (tiling_params.has_value()) { + mc = tiling_params->mc; + nc = tiling_params->nc; + } else { + auto params = LinearTilingParams::from_target_tiles_per_thread( + m, + m_step, + n, + n_step, + /*target_tiles_per_thread=*/5); + mc = params.mc; + nc = params.nc; + } + TORCHAO_CHECK(mc >= 1, "mc must be >= 1"); + TORCHAO_CHECK(nc >= 1, "nc must be >= 1"); + TORCHAO_CHECK( + (mc == m) || (mc % m_step == 0), + "mc from tiling_params must be m or a multiple of m_step"); + TORCHAO_CHECK( + (nc == n) || (nc % n_step == 0), + "nc from tiling_params must be n or a multiple of n_step"); + int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); + auto packed_activations_size = linear_config.packed_activations_size( + mc, k, group_size, uk.has_weight_zeros, linear_config.mr, uk.kr, uk.sr); + + auto packed_activations = torchao::make_aligned_byte_ptr( + uk.preferred_alignment, packed_activations_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer, + + linear_config.pack_activations( + packed_activations.get(), /*m=*/mc_tile_size, k, group_size, activations + activations_offset, - has_weight_zeros); + uk.has_weight_zeros, + linear_config.mr, + uk.kr, + uk.sr); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; int n_idx = nc_tile_idx * nc; int nc_tile_size = std::min(nc, n - n_idx); - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.linear_configs[0].kernel_fn( + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + group_size, + uk.weight_nbit, + uk.has_weight_zeros, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + + linear_config.kernel( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, /*n=*/nc_tile_size, k, group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer, + /*packed_weights=*/(char*)packed_weights + packed_weights_offset, + /*packed_activations=*/packed_activations.get(), clamp_min, clamp_max, - has_weight_zeros, - has_bias, + 
uk.has_weight_zeros, + uk.has_bias, has_clamp); }); } } -inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - int mr = ukernel_config.linear_configs[0].mr; - int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * mr); - int nc = std::min(n, tiling_params.nc_by_nr * nr); - int num_mc_panels = (m + mc - 1) / mc; - int num_nc_panels = (n + nc - 1) / nc; - - size_t weight_data_size = - ukernel_config.weight_packing_config.weight_data_size_fn( - nr, k, group_size, has_weight_zeros, has_bias); - size_t activation_data_size = - ukernel_config.linear_configs[0].activation_data_size_fn( - mr, k, group_size, has_weight_zeros); - - torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { - int mc_tile_idx = idx; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - int activations_offset = m_idx * k; - int activation_data_offset = (m_idx / mr) * activation_data_size; - - ukernel_config.linear_configs[0].prepare_activation_data_fn( - activation_data_buffer + activation_data_offset, - /*m=*/mc_tile_size, - k, - group_size, - activations + activations_offset, - has_weight_zeros); - }); - - torchao::parallel_1d(0, num_mc_panels * num_nc_panels, [&](int64_t idx) { - int mc_tile_idx = idx / num_nc_panels; - int m_idx = mc_tile_idx * mc; - int mc_tile_size = std::min(mc, m - m_idx); - - int nc_tile_idx = idx % num_nc_panels; - int n_idx = nc_tile_idx * nc; - int nc_tile_size = std::min(nc, n - n_idx); - - int activation_data_offset = (m_idx / mr) * activation_data_size; - int output_offset = m_idx * n + n_idx; - int weight_data_offset = (n_idx / nr) * weight_data_size; - - ukernel_config.linear_configs[0].kernel_fn( - output + output_offset, - /*output_m_stride=*/n, - /*m=*/mc_tile_size, - /*n=*/nc_tile_size, - k, - group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer + activation_data_offset, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - }); -} -} // namespace internal - -void linear_operator( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp) { - TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); - TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - internal::linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - internal:: - linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - 
tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max, - has_weight_zeros, - has_bias, - has_clamp); - break; - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros) { - switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - m, - k, - group_size, - has_weight_zeros); - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); - } -} - -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index dba0adb32d..accc5be5a1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -7,102 +7,17 @@ #pragma once #include #include +#include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { -struct UKernelConfig { - using activation_data_size_fn_type = - size_t (*)(int m, int k, int group_size, bool has_weight_zeros); - using prepare_activation_data_fn_type = void (*)( - void* activation_data, - int m, - int k, - int group_size, - const float* activations, - bool has_weight_zeros); - using weight_data_size_fn_type = size_t (*)( - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias); - using prepare_weight_data_fn_type = void (*)( - void* weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias); - using kernel_fn_type = void (*)( - float* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); - - struct weight_packing_config_type { - weight_data_size_fn_type weight_data_size_fn{nullptr}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - }; - struct linear_config_type { - int mr{0}; - activation_data_size_fn_type activation_data_size_fn{nullptr}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - kernel_fn_type kernel_fn{nullptr}; - }; - - // preferred_alignment for activation and weight data - // Integration surfaces are not required to respect this alignment, and the - // ukernel must behave correctly no matter how buffers are aligned - size_t preferred_alignment{0}; - int nr{0}; - weight_packing_config_type weight_packing_config; - std::array linear_configs; -}; - -// Pack weight functions -struct PackWeightDataTilingParams { - int nc_by_nr{1}; -}; - -PackWeightDataTilingParams 
get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread = 1); - -inline size_t get_packed_weight_data_size( - const UKernelConfig& ukernel_config, - int n, - int k, - int group_size, - bool has_weight_zeros, - bool has_bias) { - return ukernel_config.weight_packing_config.weight_data_size_fn( - n, k, group_size, has_weight_zeros, has_bias); -} - -inline size_t get_preferred_packed_weight_data_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, +void pack_weights_operator( + const UKernelConfig& uk, // Outputs - void* weight_data, + void* packed_weights, // Inputs int n, int k, @@ -114,40 +29,23 @@ void pack_weight_data_operator( // Linear functions struct LinearTilingParams { - int mc_by_mr{1}; - int nc_by_nr{1}; -}; - -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, - int m, - int n, - int target_tiles_per_thread = 5); + int mc{0}; + int nc{0}; -enum class LinearTileSchedulingPolicy { - single_mc_parallel_nc, - parallel_mc_parallel_nc + // Returns LinearTilingParams with mc and nc chosen so that there are + // approximately target_tiles_per_thread tiles per thread. The method + // guarantees 1. mc = m or mc % m_step == 0, and 2. nc = n or nc % n_step == 0 + static LinearTilingParams from_target_tiles_per_thread( + int m, + int m_step, + int n, + int n_step, + int target_tiles_per_thread); }; -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size, - bool has_weight_zeros); - -inline size_t get_preferred_activation_data_buffer_alignment( - const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_alignment; -} - void linear_operator( const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, + const std::optional& tiling_params, // Outputs float* output, // Inputs @@ -155,13 +53,11 @@ void linear_operator( int n, int k, int group_size, - const void* weight_data, + const void* packed_weights, const float* activations, + bool has_clamp, float clamp_min, - float clamp_max, - bool has_weight_zeros, - bool has_bias, - bool has_clamp); + float clamp_max); } // namespace // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 636fc01c64..065a5b0319 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -69,29 +69,31 @@ Tensor pack_weights_cpu( bias_ptr = bias.value().const_data_ptr(); } - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); auto packed_weights_header = packed_weights_format.to_packed_weights_header(); - auto ukernel_config = - 
select_ukernel_config(packed_weights_header); - - auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( - ukernel_config, n, /*target_panels_per_thread=*/1); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_header); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); packed_weights_header.write(packed_weights.mutable_data_ptr()); - // TODO: support passing in bias in future - pack_weight_data_operator( - ukernel_config, - pack_weight_tiling_params, + torchao::ops::linear_8bit_act_xbit_weight::pack_weights_operator( + uk, packed_weights.mutable_data_ptr() + torchao::ops::PackedWeightsHeader::size(), n, @@ -122,18 +124,26 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - auto packed_weights_format = select_packed_weights_format( - target, has_weight_zeros, has_bias); - auto ukernel_config = - select_ukernel_config(packed_weights_format); - - auto packed_weight_data_size = - torchao::ops::PackedWeightsHeader::size() + - get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); + auto packed_weights_format = + torchao::ops::linear_8bit_act_xbit_weight::select_packed_weights_format< + weight_nbit>(target, has_weight_zeros, has_bias); + auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(packed_weights_format); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + uk.nr, + uk.kr, + uk.sr); + + auto options = + torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } #endif // USE_ATEN @@ -169,8 +179,6 @@ Tensor linear_out_cpu( // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); - using namespace torchao::ops::linear_8bit_act_xbit_weight; - TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); #ifdef USE_ATEN TORCHAO_CHECK( @@ -182,36 +190,12 @@ Tensor linear_out_cpu( auto header = torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); - auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat:: - from_packed_weights_header(header); - - auto ukernel_config = select_ukernel_config(header); - - auto linear_tiling_params = get_default_linear_tiling_params( - ukernel_config, - m, - n, - /*target_tiles_per_thread=*/5); - - auto linear_scheduling_policy = - LinearTileSchedulingPolicy::single_mc_parallel_nc; - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - format.has_weight_zeros); - - std::vector activation_data_buffer(activation_data_buffer_size); + auto uk = 
torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< + weight_nbit>(header); - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.data(), + torchao::ops::linear_8bit_act_xbit_weight::linear_operator( + uk, + std::nullopt, out.mutable_data_ptr(), m, n, @@ -220,13 +204,9 @@ Tensor linear_out_cpu( packed_weights.const_data_ptr() + torchao::ops::PackedWeightsHeader::size(), activations.const_data_ptr(), - // Clamp parameters are ignored because config is created from - // has_clamp = false + /*has_clamp=*/false, /*clamp_min=*/0.0, - /*clamp_max=*/0.0, - format.has_weight_zeros, - format.has_bias, - /*has_clamp*/ false); + /*clamp_max=*/0.0); return out; } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h index 82beea43fb..e22082f9f1 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h @@ -53,10 +53,10 @@ struct PackedWeightsFormat { } }; -template -void check_format( +inline void check_format( PackedWeightsFormat format, - torchao::ops::PackedWeightsType type) { + torchao::ops::PackedWeightsType type, + int weight_nbit) { if (format.type != type) { throw std::runtime_error( "Kernel expects packed_weights type=" + diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt index 8a9ad08f23..8245fdd746 100644 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ b/torchao/experimental/ops/tests/CMakeLists.txt @@ -24,6 +24,7 @@ enable_testing() if(TORCHAO_BUILD_CPU_AARCH64) add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) endif() if(TORCHAO_BUILD_KLEIDIAI) diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index ae11b56e42..1d4127a43e 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -6,7 +6,9 @@ #include // TODO: move test_utils.h out of aarch64 -#include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#include +#endif // TORCHAO_BUILD_CPU_AARCH64 #include #include #include @@ -19,30 +21,48 @@ using namespace torchao::kernels::cpu::aarch64::kleidi:: #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; -const float kTolKleidiAI = 1.0e-2; +const float kTolKleidiAI = 5.0e-2; using namespace torchao::ops::linear_8bit_act_xbit_weight; template UKernelConfig get_ukernel_config() { namespace kernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - return UKernelConfig{ - /*preferred_alignment*/ 16, - /*nr*/ 8, - /*weight_packing_config*/ - {/*weight_data_size_fn*/ - &kernel::weight_data_size, - /*prepare_weight_data_fn*/ - &kernel::prepare_weight_data}, - /*linear_configs*/ - {{{/*mr*/ 1, - /*activation_data_size_fn*/ - &kernel::activation_data_size, - /*prepare_activation_data_fn*/ - &kernel::prepare_activation_data, - /*kernel*/ - &kernel::kernel}}}}; + channelwise_8bit_activation_groupwise_lowbit_weight; + + int preferred_alignment = 16; + int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + constexpr int mr = 1; + int m_step = 1; + constexpr bool has_lut = false; + + auto uk = 
UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_weight_zeros, + has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*linear_configs*/ {}); + + uk.linear_configs[0] = UKernelConfig::linear_config_type{ + m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel:: + kernel_1x8x16_f32_neondot}; + + return uk; } template < @@ -82,87 +102,68 @@ void test_linear_8bit_act_xbit_weight( auto output = std::vector(m * n); - for (auto linear_scheduling_policy : - {LinearTileSchedulingPolicy::single_mc_parallel_nc, - LinearTileSchedulingPolicy::parallel_mc_parallel_nc}) { - for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); - EXPECT_EQ(torchao::get_num_threads(), num_threads); - - // Pack weights - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - auto packed_weight_data_size = get_packed_weight_data_size( - ukernel_config, n, k, group_size, has_weight_zeros, has_bias); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - int8_t* weight_zeros_ptr = nullptr; - if (has_weight_zeros) { - weight_zeros_ptr = test_case.weight_zeros.data(); - } - float* bias_ptr = nullptr; - if (has_bias) { - bias_ptr = test_case.bias.data(); - } - pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - packed_weight_data.get(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - weight_zeros_ptr, - bias_ptr); - - // Allocate activation buffer - auto linear_tiling_params = - get_default_linear_tiling_params(ukernel_config, m, n); - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, - group_size, - has_weight_zeros); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - // Run linear - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.get(), - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.clamp_min, - test_case.clamp_max, - has_weight_zeros, - has_bias, - has_clamp); + for (auto num_threads : {1, 4, 500}) { + torchao::set_num_threads(num_threads); + EXPECT_EQ(torchao::get_num_threads(), num_threads); - // Test correctness - float tol = kTol; - if (has_kleidi) { - tol = kTolKleidiAI; - } - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], tol); - } + // Pack weights + auto packed_weight_data_size = ukernel_config.packed_weights_size( + n, + k, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + ukernel_config.nr, + ukernel_config.kr, + ukernel_config.sr); + auto preferred_packed_weight_data_alignment = + ukernel_config.preferred_alignment; + auto packed_weights = torchao::make_aligned_byte_ptr( + preferred_packed_weight_data_alignment, packed_weight_data_size); + + int8_t* weight_zeros_ptr = nullptr; + if (has_weight_zeros) { + weight_zeros_ptr = 
test_case.weight_zeros.data(); + } + float* bias_ptr = nullptr; + // kleidi always has bias in these tests + if (has_bias || has_kleidi) { + bias_ptr = test_case.bias.data(); + } + + pack_weights_operator( + ukernel_config, + packed_weights.get(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + weight_zeros_ptr, + bias_ptr); + + linear_operator( + ukernel_config, + std::nullopt, + output.data(), + m, + n, + k, + group_size, + packed_weights.get(), + test_case.activations.data(), + has_clamp, + test_case.clamp_min, + test_case.clamp_max); + + // Test correctness + float tol = kTol; + if (has_kleidi) { + tol = kTolKleidiAI; + } + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], tol); } } } @@ -176,102 +177,136 @@ enum kai_kernel_id { i8mm_8x4x32 }; -template < - typename kernel_struct, - int m_step, - int mr, - int n_step, - int nr, - int kr, - int sr> -UKernelConfig get_ukernel_config_kleidi() { +template +UKernelConfig get_ukernel_config_kleidi_impl() { namespace op = torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = kernel_struct::get_ukernel(); - assert(m_step == uk.get_m_step()); - assert(mr == uk.get_mr()); - assert(n_step == uk.get_n_step()); - assert(nr == uk.get_nr()); - assert(kr == uk.get_kr()); - assert(sr == uk.get_sr()); - return UKernelConfig{ + auto ukernel_config = UKernelConfig::make( op::get_preferred_alignement(), - n_step, - {/*weight_data_size_fn*/ &op::weight_data_size, - /*prepare_weight_data_fn*/ &op::prepare_weight_data}, - {{{m_step, - &op::activation_data_size, - &op::prepare_activation_data, - &kernel_struct::kernel}}}}; + uk.get_n_step(), + uk.get_nr(), + uk.get_kr(), + uk.get_sr(), + /*weight_nbit*/ 4, + /*has_weight_zeros*/ false, + /*has_bias*/ true, + &op::packed_weights_size, + &op::packed_weights_offset, + &op::pack_weights, + /*linear_configs*/ {}); + + ukernel_config.linear_configs[0] = UKernelConfig::linear_config_type{ + static_cast(uk.get_m_step()), + static_cast(uk.get_mr()), + &op::packed_activations_size, + &op::packed_activations_offset, + &op::pack_activations, + &kernel_struct::kernel}; + + return ukernel_config; +} + +template +void test_linear_8bit_act_xbit_weight_kleidiai() { + constexpr int weight_nbit = 4; + constexpr bool has_kleidi = true; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + auto uk = get_ukernel_config_kleidi_impl(); + + for (auto m : {1, 3, 4, 8, 9, 13, 21, 43, 101}) { + for (auto n : + {1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 4 * 13, + 4 * 13 + 3, + 8 * 13, + 8 * 13 + 3, + 16 * 13, + 16 * 13 + 3}) { + for (auto k : {32, 64, 128}) { + int group_size = 32; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ false, + has_kleidi>(m, n, k, group_size, &uk); + + if (k >= 64) { + group_size = 64; + test_linear_8bit_act_xbit_weight< + weight_nbit, + has_weight_zeros, + has_bias, + /*has_clamp*/ true, + has_kleidi>(m, n, k, group_size, &uk); + } + } + } + } +} + +#if defined(TORCHAO_ENABLE_ARM_NEON_DOT) +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + 
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod>(); +} +TEST( + test_linear_8bit_act_xbit_weight, + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod) { + test_linear_8bit_act_xbit_weight_kleidiai< + matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod>(); +} +#endif // TORCHAO_ENABLE_ARM_NEON_DOT template UKernelConfig get_ukernel_config_kleidi() { #if defined(TORCHAO_ENABLE_ARM_I8MM) if constexpr (kernel_id == i8mm_4x8x32) { - constexpr int m_step = 4; - constexpr int mr = 4; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm>(); } if constexpr (kernel_id == i8mm_8x4x32) { - constexpr int m_step = 8; - constexpr int mr = 8; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm>(); } #endif // TORCHAO_ENABLE_ARM_I8MM if constexpr (kernel_id == dotprod_1x8x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 8; - constexpr int nr = 8; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod>(); } if constexpr (kernel_id == dotprod_1x4x32) { - constexpr int m_step = 1; - constexpr int mr = 1; - constexpr int n_step = 4; - constexpr int nr = 4; - constexpr int kr = 16; - constexpr int sr = 2; - return get_ukernel_config_kleidi< - matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_step, - mr, - n_step, - nr, - kr, - sr>(); + return get_ukernel_config_kleidi_impl< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod>(); } throw std::runtime_error("Unsupported kernel_id"); } @@ -332,15 +367,11 @@ TEST(test_linear_8bit_act_xbit_weight, KNotDivisibleByGroupSize) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); - EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, @@ -362,15 +393,12 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { true /*has_weight_zeros*/, true /*has_bias*/, true /*has_clamp*/>(); - auto pack_weight_data_tiling_params = - get_default_pack_weight_data_tiling_params(ukernel_config, n); EXPECT_THROW( { - pack_weight_data_operator( + pack_weights_operator( ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, + /*packed_weights=*/nullptr, n, k, group_size, 
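The scheduling machinery removed above reduces, in the new API, to two small pieces of integer arithmetic: UKernelConfig::select_linear_config_idx picks the registered linear config with the largest m_step that does not exceed m, and LinearTilingParams::from_target_tiles_per_thread sizes mc/nc so that each thread sees roughly target_tiles_per_thread tiles. The Python sketch below is illustrative only: it assumes num_threads is supplied by the caller and that the denominator is num_threads * target_tiles_per_thread (inferred from the analogous packing heuristic); the authoritative implementation is the C++ in linear_8bit_act_xbit_weight.cpp/.h.

def select_linear_config_idx(m_steps, m):
    # m_steps: ascending m_step values of the registered linear configs.
    # Pick the last config whose m_step does not exceed m; fall back to the first.
    idx = 0
    for i, m_step in enumerate(m_steps):
        if m_step <= m:
            idx = i
    return idx


def tiling_from_target_tiles_per_thread(m, m_step, n, n_step, num_threads,
                                        target_tiles_per_thread=5):
    # mc is a single m_step; nc is sized so that roughly target_tiles_per_thread
    # (mc x nc) tiles land on each thread, rounded up to a multiple of n_step.
    mc = m_step
    num_mc_panels = (m + mc - 1) // mc
    denominator = num_threads * target_tiles_per_thread
    nc = (n * num_mc_panels + denominator - 1) // denominator
    nc = ((nc + n_step - 1) // n_step) * n_step
    # Clamp to the problem size; the result is m (resp. n) or a multiple of the step.
    return min(m, mc), min(n, nc)


# e.g. with registered m_steps [1, 4], m = 13 selects the m_step=4 config
assert select_linear_config_idx([1, 4], 13) == 1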
diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 098fc09696..dcd8eb74d5 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -105,6 +105,58 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity): expected_result = quantized_model_reference(activations) self._assert_close(result, expected_result) + def test_accuracy_kleidiai(self): + n = 1071 + k = 2048 + model = torch.nn.Sequential( + *[torch.nn.Linear(k, k, bias=False), torch.nn.Linear(k, n, bias=True)] + ) + weight_dtype = torch.int4 + granularity = PerGroup(128) + has_weight_zeros = False + + # We set round_weight_scale_to_bf16 to True for accuracy testing because + # some KleidiAI kernels do this internally + round_weight_scale_to_bf16 = True + + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="kleidiai" + ), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=self._reference_layout(), + round_weight_scale_to_bf16=round_weight_scale_to_bf16, + ), + ) + + with torch.no_grad(): + for m in [1, 3, 5, 9, 13]: + activations = torch.randn(m, k) + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + + # KleidiAI kernels require much higher tolerance when comparing to reference, + # especially for GEMM kernels + self._assert_close( + result, expected_result, mse_tol=1e-2, atol=1e-2, rtol=1 + ) + def test_accuracy_aten(self): m = 3 n = 1024 @@ -151,9 +203,21 @@ def test_accuracy_aten(self): self._assert_close(result, expected_result) - def _assert_close(self, result, expected_result): - self.assertTrue(torch.nn.functional.mse_loss(result, expected_result) <= 1e-6) - self.assertTrue(torch.allclose(result, expected_result, atol=1e-2)) + def _assert_close( + self, result, expected_result, mse_tol=1e-6, atol=1e-2, rtol=1e-5 + ): + mse_loss = torch.nn.functional.mse_loss(result, expected_result) + self.assertTrue( + mse_loss <= mse_tol, + f"Got mse_loss={mse_loss}, above mse tolerance {mse_tol}", + ) + + n_rand_idxs = 5 + rand_idxs = torch.randint(0, result.numel(), (n_rand_idxs,)) + self.assertTrue( + torch.allclose(result, expected_result, atol=atol, rtol=rtol), + f"Failed allclose at atol={atol}, rtol={rtol}. 
On {n_rand_idxs} random indices, we have result={result.reshape(-1)[rand_idxs]} vs expected_result={expected_result.reshape(-1)[rand_idxs]}.", + ) def _reference_layout(self): return PlainLayout() diff --git a/torchao/ops.py b/torchao/ops.py index 34a97d03f5..5bc71321ac 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -71,6 +71,13 @@ def decorator(func): return decorator +@functools.lru_cache +def cached_compute_capability(): + device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) + compute_capability = device_props.major * 10 + device_props.minor + return compute_capability + + def quant_llm_linear( EXPONENT: int, MANTISSA: int, @@ -93,6 +100,12 @@ def quant_llm_linear( Returns output of linear layer """ + # Check if we're on a supported architecture (sm7.5 or higher) + compute_capability = cached_compute_capability() + torch._check( + compute_capability >= 75, + lambda: f"quant_llm_linear requires sm7.5+ GPU architecture, but current device has sm{compute_capability}", + ) return torch.ops.torchao.quant_llm_linear.default( EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK ) diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py index f62569cbb4..7df5ce768c 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_test.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_test.py @@ -7,7 +7,7 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig from torchao.prototype.float8nocompile.float8nocompile_linear import ( matmul_with_args_in_hp, @@ -72,7 +72,7 @@ def test_matmul_with_args_in_hp(input_shape: tuple[int, int]): ) # prod forward. expects transposed weight. - out_prod = manual_float8_matmul_with_args_in_hp.apply( + out_prod = matmul_with_hp_or_float8_args.apply( prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config ) diff --git a/torchao/prototype/scaled_grouped_mm/__init__.py b/torchao/prototype/scaled_grouped_mm/__init__.py new file mode 100644 index 0000000000..9c6278884a --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/__init__.py @@ -0,0 +1,3 @@ +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm as _scaled_grouped_mm, +) diff --git a/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py new file mode 100644 index 0000000000..a431288c07 --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py @@ -0,0 +1,361 @@ +from typing import Optional, Tuple + +import torch + +from torchao.float8.config import ScalingGranularity +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated + + +def _scaled_grouped_mm( + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + This function performs dynamic float8 quantization with row-wise scaling + on the input tensors A and B, then performs a scaled grouped GEMM and returns the results. + + Args: + A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K) + and in row-major memory layout. 
+ B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N) + and in column-major memory layout. + offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. + out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. + """ + return _Float8GroupedMM.apply( + A, + B_t, + offs, + out_dtype, + ) + + +class _Float8GroupedMM(torch.autograd.Function): + """Differentiable implementation of grouped GEMM with dynamic float8 quantization.""" + + @staticmethod + def forward( + ctx, + A: torch.Tensor, + B_t: torch.Tensor, + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + # torchao _scaled_grouped_mm only supports A=2D, B=3D. + assert A.ndim == 2, "A must be 2D" + assert B_t.ndim == 3, "B must be 3D" + + assert ( + A.size(-1) % 16 == 0 + ), f"A must have a last dim divisible by 16, but got shape: {A.shape}" + assert ( + B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0 + ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}" + + # Assert input tensors are in high-precision dtypes. + assert ( + A.dtype == torch.float32 or A.dtype == torch.bfloat16 + ), "A must be float32 or bfloat16" + assert ( + B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16 + ), "B must be float32 or bfloat16" + assert offs.dtype == torch.int32, "offs must be int32" + + # Assert A and B dims are compatible for a scaled grouped GEMM. + assert A.size(-1) == B_t.size( + -2 + ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm" + + # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements. + assert not _is_column_major(A), "A must be row-major" + + # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major. + assert _is_column_major(B_t), "B must be column-major" + + # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM. + # A shape: (M, K) + # A_scales shape: (M,1) + A_scales = tensor_to_scale( + A, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B to float8, column-major for right operand of grouped GEMM. + # B shape: (B, K, N) + # B scales must be computed rowwise keeping the outer/final dim, so: + # B_scales shape: (B, 1, N) + B_t_scales = tensor_to_scale( + B_t, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) + + # Precompute non-transposed B column-major for backward, to save memory by storing the + # low precision B tensor instead of the high precision B tensor. + # In the backward this is needed for grad_A: grad_output @ B. 
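+        # `.contiguous()` first materializes B_t as a row-major copy; the
+        # `.transpose(-2, -1)` view of that copy then has stride 1 along dim -2
+        # and stride > 1 along dim -1, i.e. it is column-major over its last two
+        # dims, which is the layout _is_column_major checks and the grouped GEMM
+        # expects for its right operand.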
+ B = B_t.contiguous().transpose(-2, -1) + + # - B shape: (B, K, N) + # - B scales must be computed rowwise keeping the outer/final dim, so: + # - B_scale shape: (B, 1, N) + B_scales = tensor_to_scale( + B, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_scaled = B.to(torch.float32) * B_scales + B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) + + # Store what we need for backward. + ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.out_dtype = out_dtype + + # Perform scaled grouped GEMM and return result. + # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N) + return torch._scaled_grouped_mm( + A_fp8_row_major, + B_t_fp8_col_major, + A_scales.squeeze().reciprocal(), + B_t_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + out_dtype = ctx.out_dtype + + # Convert grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_A: grad_output @ B + # + # grad_output shape: (M, N) + # grad_output_scale shape: (M, 1) + grad_output_scales = tensor_to_scale( + grad_output, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales + grad_output_fp8_row_major = to_fp8_saturated( + grad_output_scaled, torch.float8_e4m3fn + ) + + # Compute grad_A. + # + # grad_A = grad_output @ B + # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) + grad_A = torch._scaled_grouped_mm( + grad_output_fp8_row_major, + B_fp8_col_major, + grad_output_scales.squeeze().reciprocal(), + B_scales.squeeze().reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + + # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_B: grad_output_t @ A + grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous() + + # Convert A to float8, column-major for right operand of grouped GEMM: + # needed for grad_B: grad_output @ A + A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1) + + # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups." + # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups. + grad_output_t_fp8_row_major, grad_output_t_scales = ( + _to_2d_jagged_float8_tensor_rowwise( + grad_output_t_row_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + ) + A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise( + A_col_major, + offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + + # Compute grad_B = grad_output_t @ A. 
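+        # Both operands of this grouped GEMM are 2D; `offs` partitions the
+        # columns of grad_output_t and the rows of A into groups, so the jagged
+        # scales computed just above are applied per group slice rather than
+        # once over the whole tensor.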
+ # grad_B = grad_output_t @ A + # grad_B = (N,M) @ (M,K) = (N,K) + grad_B = torch._scaled_grouped_mm( + grad_output_t_fp8_row_major, + A_fp8_col_major, + grad_output_t_scales.reciprocal(), + A_scales.reciprocal(), + offs, + out_dtype=out_dtype, + use_fast_accum=True, + ) + return grad_A, grad_B.transpose(-2, -1), None, None, None, None + + +def _to_2d_jagged_float8_tensor_colwise( + A_col_major: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor A to a jagged float8 tensor, + with scales computed along *logical columns* for each group individually, + where groups are determined based on the offsets. + + For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns. + (i.e., a tensor of (K,N) will have scales of shape (1,N). + + However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical columns and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results + in scales of shape (1,N * num_groups). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert A_col_major.ndim == 2, "A must be 2D" + + num_groups = offs.numel() + A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype) + A_scales = torch.empty( + A_fp8_col_major.size(1) * num_groups, + dtype=torch.float32, + device=A_fp8_col_major.device, + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each. + subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, K) + + # Compute local rowwise scales for this subtensor, which are along logical columns for the right operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=0, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor + A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return A_fp8_col_major, A_scales + + +def _to_2d_jagged_float8_tensor_rowwise( + x: torch.Tensor, + offs: torch.Tensor, + target_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 2D input tensor to a jagged float8 tensor, + with scales computed along *logical rows* for each group individually, + where groups are determined based on the offsets. + + For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows. + (i.e., a tensor of (M,K) will have scales of shape (M,1). 
+ + However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct + groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales + along the logical rows and apply it to the entire tensor. + + Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results + in scales of shape (M * num_groups, 1). + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor. + + Returns: + A tuple containing the jagged float8 tensor and the scales used for the conversion. + """ + assert x.ndim == 2, "input tensor must be 2D" + + num_groups = offs.numel() + x_fp8 = torch.empty_like(x, dtype=target_dtype) + x_scales = torch.empty( + x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device + ) + + start_idx = 0 + next_scale_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching all rows with the next group of rows. + subtensor = x[:, start_idx:end_idx] # (M, local_group_size) + + # Compute local rowwise scales for this subtensor, which are along logical rows for the left operand. + subtensor_scales = tensor_to_scale( + subtensor, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Apply scales to subtensor and convert to float8. + tensor_scaled = subtensor.to(torch.float32) * subtensor_scales + float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # Store this portion of the resulting float8 tensor and scales. + x_fp8[:, start_idx:end_idx] = float8_subtensor + x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = ( + subtensor_scales.squeeze() + ) + + # Update start index for next group. + start_idx = end_idx + next_scale_idx += subtensor_scales.numel() + + return x_fp8, x_scales + + +def _is_column_major(x: torch.Tensor) -> bool: + """ + This function checks if the input tensor is column-major. + + Args: + x (torch.Tensor): The input tensor to be checked. + + Returns: + A boolean indicating whether the input tensor is column-major. + """ + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-2) == 1 and x.stride(-1) > 1 diff --git a/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py new file mode 100644 index 0000000000..cd347c3d9d --- /dev/null +++ b/torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py @@ -0,0 +1,196 @@ +import pytest +import torch + +from torchao.float8.config import ( + Float8LinearConfig, + Float8LinearRecipeName, +) +from torchao.float8.float8_linear import matmul_with_hp_or_float8_args +from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( + _scaled_grouped_mm, +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_valid_scaled_grouped_mm_2d_3d(): + out_dtype = torch.bfloat16 + device = "cuda" + m, n, k, n_groups = 16, 32, 16, 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + dtype=torch.bfloat16, + ) + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # b must be transposed and in column major format. 
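Editor's note (not part of the patch): the two _to_2d_jagged_float8_tensor_* helpers above compute a separate set of scales for every offs-delimited group rather than one set for the whole tensor, so an (M, K) left operand ends up with M * num_groups scales. A simplified sketch of that grouping follows (hypothetical names; the float8 cast is omitted), after which the test body continues.

# Illustrative only: how `offs` partitions one dim of a 2D operand into
# jagged groups, and why the scale count grows with the number of groups.
import torch

M, K, n_groups = 8, 64, 4
x = torch.randn(M, K)   # 2D left operand (e.g. grad_output_t)
offs = torch.arange(K // n_groups, K + 1, K // n_groups, dtype=torch.int32)  # [16, 32, 48, 64]

scales_per_group = []
start = 0
for end in offs.tolist():
    group = x[:, start:end]                        # (M, local_group_size)
    # One rowwise scale per row *per group*, not per row of the whole tensor.
    amax = group.abs().amax(dim=-1, keepdim=True)  # (M, 1)
    scales_per_group.append(torch.finfo(torch.float8_e4m3fn).max / amax.clamp(min=1e-12))
    start = end

all_scales = torch.cat([s.squeeze() for s in scales_per_group])
print(all_scales.shape)   # (M * n_groups,) = (32,), matching the docstrings above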
+ b_t = b.contiguous().transpose(-2, -1).requires_grad_(True) + + # Compute output. + out = _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + # Validate result. + ref_a = a.detach().clone().requires_grad_(True) + ref_b_t = b_t.detach().clone().requires_grad_(True) + ref_out = compute_reference_forward( + out, + ref_a, + ref_b_t, + n_groups, + out_dtype, + offs, + ) + assert torch.equal(out, ref_out) + + # Run backward pass. + out.sum().backward() + ref_out.sum().backward() + + # Validate gradients. + assert torch.equal(a.grad, ref_a.grad) + assert torch.equal(b_t.grad, ref_b_t.grad) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("m", [16, 17]) +@pytest.mark.parametrize("k", [16, 18]) +@pytest.mark.parametrize("n", [32, 33]) +def test_K_or_N_dim_not_multiple_of_16(m, n, k): + # - Leading dim of A doesn't have to be divisible by 16, since it will be + # divided up into groups based on offset anyway. + # - Trailing dim of A must be divisible by 16. + # - Leading dim of B (n_groups) doesn't need to be divisible by 16. + # - Last 2 dims of B must be divisible by 16. + if n % 16 == 0 and k % 16 == 0: + return + out_dtype = torch.bfloat16 + device = "cuda" + n_groups = 4 + a = torch.randn( + m * n_groups, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + b = torch.randn( + n_groups, + n, + k, + device=device, + requires_grad=True, + dtype=torch.bfloat16, + ) + + # b must be transposed and in column major format. + b_t = b.transpose(-2, -1) + b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1) + + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + # Compute output. + with pytest.raises(AssertionError): + _scaled_grouped_mm( + a, + b_t, + offs=offs, + out_dtype=out_dtype, + ) + + +def compute_reference_forward( + result: torch.Tensor, + A: torch.Tensor, + B_t: torch.Tensor, + n_groups: int, + out_dtype: torch.dtype, + offs: torch.Tensor, +): + assert result.dtype == out_dtype + + # Use official rowwise recipe as reference to ensure implementation is correct. + float8_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) + + # Convert A to fp8. + A_scales = tensor_to_scale( + A, + float8_config.cast_config_input.target_dtype, + scaling_granularity=float8_config.cast_config_input.scaling_granularity, + axiswise_dim=-1, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8 = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B^t to fp8. + B_t_scales = tensor_to_scale( + B_t, + float8_config.cast_config_weight.target_dtype, + scaling_granularity=float8_config.cast_config_weight.scaling_granularity, + axiswise_dim=-2, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8 = to_fp8_saturated( + B_t_scaled, + torch.float8_e4m3fn, + ) + + # Split A and result into chunks, one for each group. + offs_cpu = offs.cpu() + A_list, A_list_fp8, A_scale_list, result_list = [], [], [], [] + start = 0 + for i in range(n_groups): + A_list.append(A[start : offs_cpu[i]]) + A_list_fp8.append(A_fp8[start : offs_cpu[i]]) + A_scale_list.append(A_scales[start : offs_cpu[i]]) + result_list.append(result[start : offs_cpu[i]]) + start = offs_cpu[i] + + # Validate each actual result group from the _scaled_grouped_mm is equal to: + # 1. A manual _scaled_mm for the group. + # 2. 
A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients). + outputs = [] + list1 = list(zip(A_list_fp8, B_t_fp8, A_scale_list, B_t_scales, result_list)) + list2 = list(zip(A_list, B_t, result_list)) + for i in range(len(list1)): + a1, b1, a1scale, b1scale, result1 = list1[i] + ref_group_result1 = torch._scaled_mm( + a1, + b1, + a1scale.reciprocal(), + b1scale.reciprocal(), + out_dtype=out_dtype, + bias=None, + use_fast_accum=float8_config.gemm_config_output.use_fast_accum, + ) + a2, b2, result2 = list2[i] + ref_group_result2 = matmul_with_hp_or_float8_args.apply( + a2, + b2, + LinearMMConfig(), + float8_config, + ) + assert torch.equal(result1, ref_group_result1) + assert torch.equal(result2, ref_group_result2) + outputs.append(ref_group_result2) + + # Concatenate the outputs and verify the full result is correct. + output_ref = torch.cat(outputs, dim=0) + return output_ref diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 6c63937051..63b1da440d 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -24,7 +24,10 @@ find_multiple, ) -from .quant_primitives import MappingType +from .quant_primitives import ( + MappingType, + dequantize_affine, +) from .unified import Quantizer from .utils import ( _MultiInput, @@ -940,19 +943,17 @@ def linear_forward_8da4w( n_bit = 4 quant_min = -(2 ** (n_bit - 1)) quant_max = 2 ** (n_bit - 1) - 1 - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_channel_group_wrapper, - ) + block_size = (1, groupsize) - w_dq = _quantized_decomposed_dequantize_per_channel_group_wrapper( + w_dq = dequantize_affine( weight_int8, + block_size, scales, zeros, + torch.int8, quant_min, quant_max, - torch.int8, - groupsize, - precision, + output_dtype=precision, ) # x = x.to(torch.float16) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 74c136ad00..b23f39c6d7 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -539,36 +539,48 @@ def group_quantize_tensor_symmetric( return w_int8, scales, zeros -def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor: - orig_dtype = input.dtype - # TODO: we may need to make the choose_qparams op configurable - from torchao._executorch_ops import ( - _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper, - ) - - ( - scales, - zero_points, - ) = _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper( - input, torch.int8 - ) - - # TODO: get these from torch.int8 +def per_token_dynamic_quant( + input: torch.Tensor, + scale_dtype: torch.dtype = torch.float32, + zero_point_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + mapping_type = MappingType.ASYMMETRIC + block_size = _get_per_token_block_size(input) quant_min = -128 quant_max = 127 - from torchao._executorch_ops import _quantized_decomposed_quantize_per_token_wrapper + quant_dtype = torch.int8 + output_dtype = input.dtype - input = _quantized_decomposed_quantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8 + scales, zero_points = choose_qparams_affine( + input, + mapping_type, + block_size, + quant_dtype, + quant_min, + quant_max, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, ) - from torchao._executorch_ops import ( - _quantized_decomposed_dequantize_per_token_wrapper, + q = quantize_affine( + input, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, ) - - input = 
_quantized_decomposed_dequantize_per_token_wrapper( - input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype + dq = dequantize_affine( + q, + block_size, + scales, + zero_points, + quant_dtype, + quant_min, + quant_max, + output_dtype=output_dtype, ) - return input.to(orig_dtype) + return dq def recommended_inductor_config_setter(): diff --git a/version.txt b/version.txt index 78bc1abd14..d9df1bbc0c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0 +0.11.0
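Editor's note (not part of the patch): the refactored per_token_dynamic_quant above is now a per-token int8 fake-quant round trip built on choose_qparams_affine / quantize_affine / dequantize_affine, preserving the input's shape and dtype. A minimal usage sketch, assuming the function remains importable from torchao.quantization.utils:

# Illustrative usage of the refactored per_token_dynamic_quant: a per-token
# int8 quantize/dequantize round trip that keeps the input's shape and dtype.
import torch
from torchao.quantization.utils import per_token_dynamic_quant

x = torch.randn(4, 128, dtype=torch.float32)
y = per_token_dynamic_quant(x)

assert y.shape == x.shape
assert y.dtype == x.dtype   # output_dtype defaults to the input dtype
# The round trip is lossy but close for per-token int8 quantization.
print((x - y).abs().max())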