Skip to content

Commit 1eb4d33

Browse files
authored
Migrate to int args (#1846)
* init * up * up * lint * up * lint * init * up * up * lint * up * up * up * up
1 parent ed6ec9c commit 1eb4d33

24 files changed

+117
-563
lines changed

.github/workflows/torchao_experimental_test.yml

+10-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ jobs:
3333
- name: Install requirements
3434
run: |
3535
conda activate venv
36-
pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
36+
# Install executorch first because it installs its own version
37+
# of torch and torchao, which we do not want to use
38+
pip install executorch
39+
pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
3740
pip install numpy
3841
pip install pytest
3942
pip install parameterized
@@ -57,6 +60,12 @@ jobs:
5760
sh build_and_run_tests.sh
5861
rm -rf /tmp/cmake-out
5962
popd
63+
- name: ET ops build
64+
run: |
65+
conda activate venv
66+
pushd torchao/experimental
67+
sh build_torchao_ops.sh executorch
68+
popd
6069
6170
test-mps-ops:
6271
strategy:

torchao/experimental/build_torchao_ops.sh

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export CMAKE_OUT=cmake-out
2121
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
2222
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
2323
-DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \
24+
-DTORCHAO_BUILD_CPU_AARCH64=ON \
2425
-S . \
2526
-B ${CMAKE_OUT}
2627
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release

torchao/experimental/ops/embedding_xbit/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
2727
# libextension_threadpool.a
2828
# libcpuinfo.a
2929
# libpthreadpool.a
30+
if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES)
31+
message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.")
32+
find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake)
33+
endif()
3034
add_library(torchao_ops_embedding_xbit_executorch OBJECT
3135
op_embedding_xbit_executorch.cpp
3236
)

torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt

+6-2
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
4141
# libextension_threadpool.a
4242
# libcpuinfo.a
4343
# libpthreadpool.a
44+
if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES)
45+
message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.")
46+
find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake)
47+
endif()
4448
# find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake)
45-
file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/op_linear_8bit_act_xbit_weight_executorch/*.cpp")
49+
# file(GLOB _SRCS "${CMAKE_CURRENT_SOURCE_DIR}/op_linear_8bit_act_xbit_weight_executorch/*.cpp")
4650
add_library(torchao_ops_linear_8bit_act_xbit_weight_executorch OBJECT
4751
linear_8bit_act_xbit_weight.cpp
48-
${_SRCS}
52+
op_linear_8bit_act_xbit_weight_executorch.cpp
4953
)
5054
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_executorch executorch)
5155
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}")

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h

+15-38
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,7 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales,
8989
template <int weight_nbit>
9090
Tensor pack_weights_without_zeros_cpu(
9191
const Tensor &weight_qvals, const Tensor &weight_scales,
92-
// TODO(T200095131): convert to int64_t when supported by AOTI
93-
// group_size is a tensor with size (0, group_size)
94-
const Tensor &group_size_tensor) {
95-
int64_t group_size = group_size_tensor.size(1);
92+
const int64_t& group_size) {
9693
return pack_weights_cpu<weight_nbit,
9794
/*has_weight_zeros*/ false,
9895
/*has_bias*/ false>(weight_qvals, weight_scales,
@@ -105,10 +102,8 @@ template <int weight_nbit>
105102
Tensor pack_weights_with_zeros_cpu(
106103
const Tensor &weight_qvals, const Tensor &weight_scales,
107104
const Tensor &weight_zeros,
108-
// TODO(T200095131): convert to int64_t when supported by AOTI
109-
// group_size is a meta tensor with size (group_size)
110-
const Tensor &group_size_tensor) {
111-
int64_t group_size = group_size_tensor.size(1);
105+
const int64_t& group_size
106+
) {
112107
return pack_weights_cpu<weight_nbit,
113108
/*has_weight_zeros*/ true,
114109
/*has_bias*/ false>(weight_qvals, weight_scales,
@@ -145,10 +140,8 @@ Tensor pack_weights_meta(const Tensor &weight_qvals,
145140
template <int weight_nbit>
146141
Tensor pack_weights_without_zeros_meta(
147142
const Tensor &weight_qvals, const Tensor &weight_scales,
148-
// TODO(T200095131): convert to int64_t when supported by AOTI
149-
// group_size is a meta tensor with size (group_size)
150-
const Tensor &group_size_tensor) {
151-
int64_t group_size = group_size_tensor.size(1);
143+
const int64_t& group_size
144+
) {
152145
return pack_weights_meta<weight_nbit,
153146
/*has_weight_zeros*/ false,
154147
/*has_bias*/ false>(weight_qvals, weight_scales,
@@ -161,10 +154,8 @@ template <int weight_nbit>
161154
Tensor pack_weights_with_zeros_meta(
162155
const Tensor &weight_qvals, const Tensor &weight_scales,
163156
const Tensor &weight_zeros,
164-
// TODO(T200095131): convert to int64_t when supported by AOTI
165-
// group_size is a meta tensor with size (group_size)
166-
const Tensor &group_size_tensor) {
167-
int64_t group_size = group_size_tensor.size(1);
157+
const int64_t& group_size
158+
) {
168159
return pack_weights_meta<weight_nbit,
169160
/*has_weight_zeros*/ true,
170161
/*has_bias*/ false>(weight_qvals, weight_scales,
@@ -176,14 +167,8 @@ Tensor pack_weights_with_zeros_meta(
176167
template <int weight_nbit, bool has_weight_zeros>
177168
Tensor
178169
linear_out_cpu(const Tensor &activations, const Tensor &packed_weights,
179-
// TODO(T200095131): convert n_tensor, k_tensor,
180-
// group_size_tensor to int64_t when supported by AOTI Currently
181-
// they are tensors with size equal to (0, the int they wrap)
182-
const Tensor &group_size_tensor, const Tensor &n_tensor,
183-
const Tensor &k_tensor, Tensor &out) {
184-
int n = n_tensor.size(1);
185-
int k = k_tensor.size(1);
186-
int group_size = group_size_tensor.size(1);
170+
const int64_t& group_size, const int64_t& n,
171+
const int64_t& k, Tensor &out) {
187172
TORCHAO_CHECK(n >= 1, "n must be >= 1");
188173
TORCHAO_CHECK(k >= 1, "k must be >= 1");
189174
TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
@@ -261,15 +246,12 @@ linear_out_cpu(const Tensor &activations, const Tensor &packed_weights,
261246
template <int weight_nbit, bool has_weight_zeros>
262247
Tensor
263248
linear_cpu(const Tensor &activations, const Tensor &packed_weights,
264-
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
265-
// int64_t when supported by AOTI Currently they are tensors with
266-
// size equal to (0, the int they wrap)
267-
const Tensor &group_size_tensor, const Tensor &n_tensor,
268-
const Tensor &k_tensor) {
249+
const int64_t &group_size, const int64_t &n,
250+
const int64_t &k) {
269251
Tensor output_tensor = torch::empty({}, torch::kFloat32);
270252
linear_out_cpu<weight_nbit, has_weight_zeros>(activations, packed_weights,
271-
group_size_tensor, n_tensor,
272-
k_tensor, output_tensor);
253+
group_size, n,
254+
k, output_tensor);
273255
return output_tensor;
274256
}
275257
#endif // USE_ATEN
@@ -278,13 +260,8 @@ linear_cpu(const Tensor &activations, const Tensor &packed_weights,
278260
template <int weight_nbit, bool has_weight_zeros>
279261
Tensor linear_meta(
280262
const Tensor &activations, const Tensor &packed_weights,
281-
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
282-
// int64_t when supported by AOTI
283-
// Currently they are tensors with size equal to (0, the int they wrap)
284-
const Tensor &group_size_tensor, const Tensor &n_tensor,
285-
const Tensor &k_tensor) {
286-
int n = n_tensor.size(1);
287-
int k = k_tensor.size(1);
263+
const int64_t &group_size, const int64_t &n,
264+
const int64_t &k) {
288265
TORCHAO_CHECK(n >= 1, "n must be >= 1");
289266
TORCHAO_CHECK(k >= 1, "k must be >= 1");
290267

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,22 @@
99
#define DEFINE_OP(weight_nbit) \
1010
m.def( \
1111
"_pack_8bit_act_" #weight_nbit \
12-
"bit0zp_weight(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor"); \
12+
"bit0zp_weight(Tensor weight_qvals, Tensor weight_scales, int group_size) -> Tensor"); \
1313
m.def( \
1414
"_pack_8bit_act_" #weight_nbit \
15-
"bit_weight(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor"); \
15+
"bit_weight(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, int group_size) -> Tensor"); \
1616
m.def( \
1717
"_linear_8bit_act_" #weight_nbit \
18-
"bit0zp_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
18+
"bit0zp_weight(Tensor activations, Tensor packed_weights, int group_size, int n, int k) -> Tensor"); \
1919
m.def( \
2020
"_linear_8bit_act_" #weight_nbit \
21-
"bit_weight(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k) -> Tensor"); \
21+
"bit_weight(Tensor activations, Tensor packed_weights, int group_size, int n, int k) -> Tensor"); \
2222
m.def( \
2323
"_linear_8bit_act_" #weight_nbit \
24-
"bit0zp_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)"); \
24+
"bit0zp_weight.out(Tensor activations, Tensor packed_weights, int group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)"); \
2525
m.def( \
2626
"_linear_8bit_act_" #weight_nbit \
27-
"bit_weight.out(Tensor activations, Tensor packed_weights, Tensor group_size, Tensor n, Tensor k, *, Tensor(a!) out) -> Tensor(a!)")
27+
"bit_weight.out(Tensor activations, Tensor packed_weights, int group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)")
2828

2929
#define DEFINE_CPU_IMPL(weight_nbit) \
3030
m.impl( \
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
2+
3+
#define DEFINE_OP(weight_nbit) \
4+
Tensor _op_out_0zp_##weight_nbit( \
5+
RuntimeContext &ctx, const Tensor &activations, \
6+
const Tensor &packed_weights, const int64_t &group_size, \
7+
const int64_t &n, const int64_t &k, Tensor &out) { \
8+
(void)ctx; \
9+
linear_out_cpu<weight_nbit, false>(activations, packed_weights, \
10+
group_size, n, k, out); \
11+
return out; \
12+
} \
13+
Tensor _op_out_zp_##weight_nbit( \
14+
RuntimeContext &ctx, const Tensor &activations, \
15+
const Tensor &packed_weights, const int64_t &group_size, \
16+
const int64_t &n, const int64_t &k, Tensor &out) { \
17+
(void)ctx; \
18+
linear_out_cpu<weight_nbit, true>(activations, packed_weights, group_size, \
19+
n, k, out); \
20+
return out; \
21+
}
22+
23+
#define REGISTER_0ZP(weight_nbit) \
24+
EXECUTORCH_LIBRARY(torchao, \
25+
"_linear_8bit_act_" #weight_nbit "bit0zp_weight.out", \
26+
_op_out_0zp_##weight_nbit)
27+
28+
#define REGISTER_ZP(weight_nbit) \
29+
EXECUTORCH_LIBRARY(torchao, \
30+
"_linear_8bit_act_" #weight_nbit "bit_weight.out", \
31+
_op_out_zp_##weight_nbit)
32+
33+
// This looks a bit ridiculous, but I could not get it to compile with two
34+
// EXECUTORCH_LIBRARY nested inside DEFINE_OP
35+
DEFINE_OP(1)
36+
REGISTER_0ZP(1);
37+
REGISTER_ZP(1);
38+
39+
DEFINE_OP(2)
40+
REGISTER_0ZP(2);
41+
REGISTER_ZP(2);
42+
43+
DEFINE_OP(3)
44+
REGISTER_0ZP(3);
45+
REGISTER_ZP(3);
46+
47+
DEFINE_OP(4)
48+
REGISTER_0ZP(4);
49+
REGISTER_ZP(4);
50+
51+
DEFINE_OP(5)
52+
REGISTER_0ZP(5);
53+
REGISTER_ZP(5);
54+
55+
DEFINE_OP(6)
56+
REGISTER_0ZP(6);
57+
REGISTER_ZP(6);
58+
59+
DEFINE_OP(7)
60+
REGISTER_0ZP(7);
61+
REGISTER_ZP(7);
62+
63+
DEFINE_OP(8)
64+
REGISTER_0ZP(8);
65+
REGISTER_ZP(8);

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1s.cpp

-29
This file was deleted.

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w1sz.cpp

-29
This file was deleted.

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2s.cpp

-29
This file was deleted.

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w2sz.cpp

-29
This file was deleted.

0 commit comments

Comments (0)