Skip to content

Commit 7b37eb0

Browse files
authored
Make TorchAO cpp/Python extension
Differential Revision: D69634772 Pull Request resolved: #1719
1 parent f2e8f56 commit 7b37eb0

File tree

6 files changed

+42
-51
lines changed

6 files changed

+42
-51
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torchao.utils import (
2424
TORCH_VERSION_AT_LEAST_2_5,
2525
TORCH_VERSION_AT_LEAST_2_6,
26+
is_fbcode,
2627
is_sm_at_least_89,
2728
)
2829

@@ -213,6 +214,8 @@ class TestAffineQuantizedBasic(TestCase):
213214
@common_utils.parametrize("device", COMMON_DEVICES)
214215
@common_utils.parametrize("dtype", COMMON_DTYPES)
215216
def test_flatten_unflatten(self, device, dtype):
217+
if device == "cuda" and dtype == torch.bfloat16 and is_fbcode():
218+
raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode")
216219
apply_quant_list = get_quantization_functions(False, True, device)
217220
for apply_quant in apply_quant_list:
218221
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)

test/quantization/test_marlin_qqq.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import unittest
32

43
import pytest
54
import torch
@@ -19,13 +18,9 @@
1918
MappingType,
2019
choose_qparams_and_quantize_affine_qqq,
2120
)
22-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
21+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
2322

2423

25-
@unittest.skipIf(
26-
is_fbcode(),
27-
"Skipping the test in fbcode since we don't have TARGET file for kernels",
28-
)
2924
class TestMarlinQQQ(TestCase):
3025
def setUp(self):
3126
super().setUp()

test/test_ops.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
)
1919
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
2020
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
21-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
22-
23-
if is_fbcode():
24-
pytest.skip(
25-
"Skipping the test in fbcode since we don't have TARGET file for kernels"
26-
)
21+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
2722

2823
try:
2924
import torchao.ops

torchao/__init__.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
"ignore", message="Failed to initialize NumPy: No module named 'numpy'"
1010
)
1111

12-
1312
# We use this "hack" to set torchao.__version__ correctly
1413
# the version of ao is dependent on environment variables for multiple architectures
1514
# For local development this will default to whatever is version.txt
@@ -21,34 +20,28 @@
2120
except PackageNotFoundError:
2221
__version__ = "unknown" # In case this logic breaks don't break the build
2322

24-
_IS_FBCODE = (
25-
hasattr(torch._utils_internal, "IS_FBSOURCE") and torch._utils_internal.IS_FBSOURCE
26-
)
27-
if not _IS_FBCODE:
28-
try:
29-
from pathlib import Path
30-
31-
so_files = list(Path(__file__).parent.glob("_C*.so"))
32-
if len(so_files) > 0:
33-
assert (
34-
len(so_files) == 1
35-
), f"Expected one _C*.so file, found {len(so_files)}"
36-
torch.ops.load_library(so_files[0])
37-
from . import ops
38-
39-
# The following library contains CPU kernels from torchao/experimental
40-
# They are built automatically by ao/setup.py if on an ARM machine.
41-
# They can also be built outside of the torchao install process by
42-
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
43-
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
44-
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
45-
if len(experimental_lib) > 0:
46-
assert (
47-
len(experimental_lib) == 1
48-
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
49-
torch.ops.load_library(experimental_lib[0])
50-
except:
51-
logging.debug("Skipping import of cpp extensions")
23+
try:
24+
from pathlib import Path
25+
26+
so_files = list(Path(__file__).parent.glob("_C*.so"))
27+
if len(so_files) > 0:
28+
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
29+
torch.ops.load_library(str(so_files[0]))
30+
from . import ops
31+
32+
# The following library contains CPU kernels from torchao/experimental
33+
# They are built automatically by ao/setup.py if on an ARM machine.
34+
# They can also be built outside of the torchao install process by
35+
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
36+
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
37+
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
38+
if len(experimental_lib) > 0:
39+
assert (
40+
len(experimental_lib) == 1
41+
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
42+
torch.ops.load_library(str(experimental_lib[0]))
43+
except:
44+
logging.debug("Skipping import of cpp extensions")
5245

5346
from torchao.quantization import (
5447
autoquant,
@@ -64,6 +57,3 @@
6457
"testing",
6558
"ops",
6659
]
67-
68-
# test-pytorchbot
69-
# test-codev

torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@ rowwise_scaled_linear_cutlass_s4s4(
1414
" for xq and ", wq.dtype(), " for wq is not supported");
1515

1616
// Dispatch to appropriate kernel template.
17-
using ElementA = cutlass::int4b_t;
18-
using ElementB = cutlass::int4b_t;
19-
return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
17+
#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
18+
// We get ElementA/ElementB types from the header
19+
return rowwise_scaled_linear_cutlass<cutlass::int4b_t, cutlass::int4b_t>(
2020
xq, x_scale, wq, w_scale, bias);
21+
#else
22+
TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s4s4 not available");
23+
return at::Tensor{};
24+
#endif
2125
}
2226

2327
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {

torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <torch/library.h>
2-
32
#include "rowwise_scaled_linear_cutlass.cuh"
43

54
namespace torchao {
@@ -13,11 +12,16 @@ rowwise_scaled_linear_cutlass_s8s4(
1312
__func__, " : The input datatypes combination ", xq.dtype(),
1413
" for xq and ", wq.dtype(), " for wq is not supported");
1514

16-
// Dispatch to appropriate kernel template.
15+
#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
16+
// Define ElementA as int8_t since it's a standard type
1717
using ElementA = int8_t;
18-
using ElementB = cutlass::int4b_t;
19-
return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
18+
// ElementB comes from cutlass header
19+
return rowwise_scaled_linear_cutlass<ElementA, cutlass::int4b_t>(
2020
xq, x_scale, wq, w_scale, bias);
21+
#else
22+
TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s8s4 not available");
23+
return at::Tensor{};
24+
#endif
2125
}
2226

2327
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {

0 commit comments

Comments
 (0)