@@ -70,12 +70,13 @@ void check_linear_mps_args(
 }
 
 template <int nbit>
-Tensor linear_mps_kernel(
+Tensor linear_mps_kernel_out(
     const Tensor& A,
    const Tensor& B,
     int64_t group_size,
     const Tensor& S,
-    const Tensor& Z) {
+    const Tensor& Z,
+    Tensor& C) {
   TORCH_CHECK(
       A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps");
   TORCH_CHECK(
@@ -84,15 +85,15 @@ Tensor linear_mps_kernel(
       S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps");
   TORCH_CHECK(
       Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps");
+  TORCH_CHECK(
+      C.is_mps(), __func__, ": C is on ", C.device(), " but expected on mps");
 
   check_linear_mps_args<nbit>(A, B, group_size, S, Z);
 
   auto M = A.size(0);
   auto N = B.size(0);
   auto K = A.size(1);
 
-  auto C = at::empty({M, N}, A.options());
-
   LowBitQuantWeights<nbit>::linear(
       getMTLBufferStorage(A),
       getMTLBufferStorage(B),
@@ -108,6 +109,19 @@ Tensor linear_mps_kernel(
   return C;
 }
 
+template <int nbit>
+Tensor linear_mps_kernel(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t group_size,
+    const Tensor& S,
+    const Tensor& Z) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto C = at::empty({M, N}, A.options());
+  return linear_mps_kernel_out<nbit>(A, B, group_size, S, Z, C);
+}
+
 template <int nbit>
 Tensor linear_mps_kernel_meta(
     const Tensor& A,
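The hunk above splits the kernel into the usual functional/out pairing: the out variant writes into a caller-provided buffer and performs no allocation, while the functional variant keeps its old signature by allocating the output and delegating. Below is a minimal self-contained sketch of that same pattern, using a stand-in Matrix type rather than at::Tensor; all names in it are illustrative and not part of this commit.

#include <cstddef>
#include <vector>

struct Matrix {
  std::vector<float> data;
  std::size_t rows = 0;
  std::size_t cols = 0;
};

// Out variant: fills a caller-owned destination and performs no allocation,
// so callers (e.g. ahead-of-time memory planners) control the buffer.
Matrix& linear_out(const Matrix& act, const Matrix& weight, Matrix& dst) {
  // ... kernel launch writing dst.data ...
  return dst;
}

// Functional variant: allocates the destination, then delegates, mirroring
// how linear_mps_kernel now forwards to linear_mps_kernel_out.
Matrix linear(const Matrix& act, const Matrix& weight) {
  Matrix dst;
  dst.rows = act.rows;
  dst.cols = weight.rows;  // weight rows correspond to N, as with B.size(0)
  dst.data.resize(dst.rows * dst.cols);
  return linear_out(act, weight, dst);
}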
@@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
       "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
   m.def(
       "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor");
+  m.def(
+      "_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
   m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>);
   m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>);
   m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>);
+  m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>);
+  m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>);
+  m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>);
+  m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>);
+  m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>);
+  m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>);
+  m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
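The registration hunks pair each functional schema with a `.out` overload whose `Tensor(a!) out` annotation tells the dispatcher that the keyword argument is mutated and aliased by the return value. Below is a standalone sketch of the same def/impl/call pattern, using a toy op in a separate `demo` namespace so it cannot collide with torchao's real registrations; everything in it is illustrative.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Functional variant: allocates its result.
at::Tensor demo_scale(const at::Tensor& x, double s) {
  return x * s;
}

// Out variant: writes into the caller's buffer and returns it.
at::Tensor& demo_scale_out(const at::Tensor& x, double s, at::Tensor& out) {
  out.resize_(x.sizes());
  out.copy_(x * s);
  return out;
}

TORCH_LIBRARY(demo, m) {
  m.def("scale(Tensor x, float s) -> Tensor");
  m.def("scale.out(Tensor x, float s, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(demo, CompositeExplicitAutograd, m) {
  m.impl("scale", &demo_scale);
  m.impl("scale.out", &demo_scale_out);
}

// Caller side: reach the .out overload through the dispatcher. The torchao
// overloads defined above follow the same calling convention.
at::Tensor call_scale_out(const at::Tensor& x, double s) {
  auto out = at::empty_like(x);
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("demo::scale", "out")
      .typed<at::Tensor&(const at::Tensor&, double, at::Tensor&)>();
  op.call(x, s, out);
  return out;
}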