@@ -89,10 +89,7 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales,
89
89
template <int weight_nbit>
90
90
Tensor pack_weights_without_zeros_cpu (
91
91
const Tensor &weight_qvals, const Tensor &weight_scales,
92
- // TODO(T200095131): convert to int64_t when supported by AOTI
93
- // group_size is a tensor with size (0, group_size)
94
- const Tensor &group_size_tensor) {
95
- int64_t group_size = group_size_tensor.size (1 );
92
+ const int64_t & group_size) {
96
93
return pack_weights_cpu<weight_nbit,
97
94
/* has_weight_zeros*/ false ,
98
95
/* has_bias*/ false >(weight_qvals, weight_scales,
@@ -105,10 +102,8 @@ template <int weight_nbit>
105
102
Tensor pack_weights_with_zeros_cpu (
106
103
const Tensor &weight_qvals, const Tensor &weight_scales,
107
104
const Tensor &weight_zeros,
108
- // TODO(T200095131): convert to int64_t when supported by AOTI
109
- // group_size is a meta tensor with size (group_size)
110
- const Tensor &group_size_tensor) {
111
- int64_t group_size = group_size_tensor.size (1 );
105
+ const int64_t & group_size
106
+ ) {
112
107
return pack_weights_cpu<weight_nbit,
113
108
/* has_weight_zeros*/ true ,
114
109
/* has_bias*/ false >(weight_qvals, weight_scales,
@@ -145,10 +140,8 @@ Tensor pack_weights_meta(const Tensor &weight_qvals,
145
140
template <int weight_nbit>
146
141
Tensor pack_weights_without_zeros_meta (
147
142
const Tensor &weight_qvals, const Tensor &weight_scales,
148
- // TODO(T200095131): convert to int64_t when supported by AOTI
149
- // group_size is a meta tensor with size (group_size)
150
- const Tensor &group_size_tensor) {
151
- int64_t group_size = group_size_tensor.size (1 );
143
+ const int64_t & group_size
144
+ ) {
152
145
return pack_weights_meta<weight_nbit,
153
146
/* has_weight_zeros*/ false ,
154
147
/* has_bias*/ false >(weight_qvals, weight_scales,
@@ -161,10 +154,8 @@ template <int weight_nbit>
161
154
Tensor pack_weights_with_zeros_meta (
162
155
const Tensor &weight_qvals, const Tensor &weight_scales,
163
156
const Tensor &weight_zeros,
164
- // TODO(T200095131): convert to int64_t when supported by AOTI
165
- // group_size is a meta tensor with size (group_size)
166
- const Tensor &group_size_tensor) {
167
- int64_t group_size = group_size_tensor.size (1 );
157
+ const int64_t & group_size
158
+ ) {
168
159
return pack_weights_meta<weight_nbit,
169
160
/* has_weight_zeros*/ true ,
170
161
/* has_bias*/ false >(weight_qvals, weight_scales,
@@ -176,14 +167,8 @@ Tensor pack_weights_with_zeros_meta(
176
167
template <int weight_nbit, bool has_weight_zeros>
177
168
Tensor
178
169
linear_out_cpu (const Tensor &activations, const Tensor &packed_weights,
179
- // TODO(T200095131): convert n_tensor, k_tensor,
180
- // group_size_tensor to int64_t when supported by AOTI Currently
181
- // they are tensors with size equal to (0, the int they wrap)
182
- const Tensor &group_size_tensor, const Tensor &n_tensor,
183
- const Tensor &k_tensor, Tensor &out) {
184
- int n = n_tensor.size (1 );
185
- int k = k_tensor.size (1 );
186
- int group_size = group_size_tensor.size (1 );
170
+ const int64_t & group_size, const int64_t & n,
171
+ const int64_t & k, Tensor &out) {
187
172
TORCHAO_CHECK (n >= 1 , " n must be >= 1" );
188
173
TORCHAO_CHECK (k >= 1 , " k must be >= 1" );
189
174
TORCHAO_CHECK (group_size >= 1 , " group_size must be >= 1" );
@@ -261,15 +246,12 @@ linear_out_cpu(const Tensor &activations, const Tensor &packed_weights,
261
246
template <int weight_nbit, bool has_weight_zeros>
262
247
Tensor
263
248
linear_cpu (const Tensor &activations, const Tensor &packed_weights,
264
- // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
265
- // int64_t when supported by AOTI Currently they are tensors with
266
- // size equal to (0, the int they wrap)
267
- const Tensor &group_size_tensor, const Tensor &n_tensor,
268
- const Tensor &k_tensor) {
249
+ const int64_t &group_size, const int64_t &n,
250
+ const int64_t &k) {
269
251
Tensor output_tensor = torch::empty ({}, torch::kFloat32 );
270
252
linear_out_cpu<weight_nbit, has_weight_zeros>(activations, packed_weights,
271
- group_size_tensor, n_tensor ,
272
- k_tensor , output_tensor);
253
+ group_size, n ,
254
+ k , output_tensor);
273
255
return output_tensor;
274
256
}
275
257
#endif // USE_ATEN
@@ -278,13 +260,8 @@ linear_cpu(const Tensor &activations, const Tensor &packed_weights,
278
260
template <int weight_nbit, bool has_weight_zeros>
279
261
Tensor linear_meta (
280
262
const Tensor &activations, const Tensor &packed_weights,
281
- // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
282
- // int64_t when supported by AOTI
283
- // Currently they are tensors with size equal to (0, the int they wrap)
284
- const Tensor &group_size_tensor, const Tensor &n_tensor,
285
- const Tensor &k_tensor) {
286
- int n = n_tensor.size (1 );
287
- int k = k_tensor.size (1 );
263
+ const int64_t &group_size, const int64_t &n,
264
+ const int64_t &k) {
288
265
TORCHAO_CHECK (n >= 1 , " n must be >= 1" );
289
266
TORCHAO_CHECK (k >= 1 , " k must be >= 1" );
290
267
0 commit comments