
Commit 49e8f9c

Revert "Add torch._scaled_mm for CPU (pytorch#139975)"
This reverts commit 22fae4c. Reverted pytorch#139975 on behalf of https://github.com/huydhn due to third time is the charm ([comment](pytorch#139975 (comment)))
1 parent 59a0813 commit 49e8f9c
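
For context, the reverted change had registered CPU kernels for the torch._scaled_mm operator (FP8 matrix multiply with per-tensor scales). A minimal, illustrative sketch of how that path would have been exercised from Python, assuming the operator schema shown in native_functions.yaml below; the shapes, scales, and dtypes here are examples, and with this revert applied the call only dispatches on CUDA:

import torch

# Hypothetical exercise of the reverted CPU path; values are illustrative.
mat1 = torch.randn(32, 64).to(torch.float8_e4m3fn)   # FP8 inputs, as the kernel required
mat2 = torch.randn(64, 16).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0)                           # per-tensor scaling only, per the CPU checks
scale_b = torch.tensor(1.0)
out = torch._scaled_mm(mat1, mat2, scale_a, scale_b, out_dtype=torch.bfloat16)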

12 files changed: +586 -915 lines changed


aten/src/ATen/native/Blas.cpp

Lines changed: 0 additions & 83 deletions
@@ -7,11 +7,6 @@
 #include <ATen/Config.h>
 
 #include <ATen/native/mkldnn/Matmul.h>
-#include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
-#if !defined(__s390x__) && !defined(__powerpc__)
-#include <cpuinfo.h>
-#endif
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/CPUFunctions.h>
@@ -29,9 +24,6 @@
 #include <ATen/ops/mv_native.h>
 #include <ATen/ops/scalar_tensor_native.h>
 #include <ATen/ops/vdot_native.h>
-#include <ATen/ops/_scaled_mm_native.h>
-#include <ATen/ops/mul.h>
-#include <ATen/ops/matmul.h>
 #endif
 
 namespace at::meta {
@@ -230,79 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){
 
 }
 
-static Tensor&
-_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-      " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
-  auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
-  auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
-  if (bias) {
-    out_tmp.add_(bias.value());
-  }
-  out_tmp = out_tmp.to(out.scalar_type());
-  out.copy_(out_tmp);
-  return out;
-}
-
-Tensor&
-_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-#if AT_MKLDNN_ENABLED() && (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 5)
-  if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
-    return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  } else
-#endif
-  {
-    return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-  }
-}
-
-Tensor
-_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum) {
-  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
-  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
-  return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
-}
-
 } // namespace at::native
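
The deleted _scaled_mm_out_cpu_emulated fallback above reduces to a dequantize-then-matmul reference. A rough Python equivalent of that math, for readers following the diff; the function name and signature here are illustrative, not a PyTorch API:

import torch

def scaled_mm_reference(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=torch.float32):
    # Mirror of the deleted emulation: upcast FP8 inputs to fp32, apply the
    # per-tensor scales, run a plain matmul, add bias, cast to the output dtype.
    fp32_mat1 = mat1.to(torch.float32) * scale_a.item()
    fp32_mat2 = mat2.contiguous().to(torch.float32) * scale_b.item()
    out = fp32_mat1 @ fp32_mat2
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)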

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 1 addition & 126 deletions
@@ -4,7 +4,6 @@
 #include <ATen/core/Tensor.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/Linear.h>
-#include <ATen/native/Resize.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -47,20 +46,9 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
   TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
-}
-
 } // namespace at::native
 
+
 #else // AT_MKLDNN_ENABLED
 
 #include <ATen/native/mkldnn/MKLDNNCommon.h>
@@ -459,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
       TORCH_FN(mkldnn_linear_pointwise_binary));
 }
 
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out) {
-  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
-  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
-  TORCH_CHECK(
-      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
-      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
-
-  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
-  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
-      " but got ", bias->numel());
-
-  // Check types
-  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
-  // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
-  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
-
-  // Validation checks have passed lets resize the output to actual size
-  auto mat1_c = mat1.contiguous();
-  auto mat2_c = mat2.contiguous();
-  IntArrayRef mat1_sizes = mat1_c.sizes();
-  IntArrayRef mat2_sizes = mat2_c.sizes();
-  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
-
-  float input_scale = scale_a.item<float>();
-  float weight_scale = scale_b.item<float>();
-  auto src = at::native::itensor_view_from_dense(mat1_c);
-  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
-  bool with_bias = bias.has_value();
-  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
-      N = mat2_sizes[1];
-
-  std::vector<int64_t> src_dims = {M, K};
-  std::vector<int64_t> weight_dims = {K, N};
-  std::vector<int64_t> dst_dims = {M, N};
-
-  ideep::tensor dst = at::native::itensor_view_from_dense(out);
-  auto src_desc = ideep::tensor::desc(
-      src_dims,
-      get_mkldnn_dtype(mat1.scalar_type()),
-      ideep::format_tag::any);
-  auto weights_desc = ideep::tensor::desc(
-      weight_dims,
-      get_mkldnn_dtype(mat2.scalar_type()),
-      ideep::format_tag::any);
-  auto dst_desc = ideep::tensor::desc(
-      dst_dims,
-      get_mkldnn_dtype(out.scalar_type()),
-      ideep::format_tag::any);
-  ideep::tensor onednn_bias;
-  if (with_bias) {
-    auto bias_value = bias.value();
-    if (bias_value.dim() == 1) {
-      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
-      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
-    } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias_value);
-    }
-  }
-  auto bias_desc = ideep::tensor::desc();
-  if (with_bias) {
-    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
-        get_mkldnn_dtype(bias.value().scalar_type()),
-        ideep::format_tag::any);
-  }
-  auto op_attr = ideep::attr_t();
-  if (input_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
-  }
-  if (weight_scale != 1.0f) {
-    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
-  }
-
-  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
-  auto engine = ideep::engine::cpu_engine();
-  dnnl::matmul::primitive_desc primitive_desc = with_bias
-      ? dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
-      : dnnl::matmul::primitive_desc(
-            engine, src_desc, weights_desc, dst_desc, op_attr);
-  auto primitive = dnnl::matmul(primitive_desc);
-
-  // Prepare args and execute primitive
-  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
-  ideep::exec_args args;
-  args.insert({DNNL_ARG_SRC, src});
-  args.insert({DNNL_ARG_WEIGHTS, weight_t});
-  args.insert({DNNL_ARG_DST, dst});
-  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
-  if (with_bias) {
-    args.insert({DNNL_ARG_BIAS, onednn_bias});
-  }
-  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
-  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
-
-  if (input_scale != 1.0f) {
-    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
-  }
-  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
-
-  primitive.execute(ideep::stream::default_stream(), args);
-  return out;
-}
-
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED

aten/src/ATen/native/mkldnn/Linear.h

Lines changed: 0 additions & 12 deletions
@@ -35,15 +35,3 @@ C10_API Tensor mkl_linear(
 } // namespace at
 
 #endif // AT_MKLDNN_ENABLED()
-
-namespace at::native {
-Tensor&
-mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
-    const Tensor& scale_a,
-    const Tensor& scale_b,
-    const std::optional<at::Tensor>& bias,
-    const std::optional<at::Tensor>& scale_result,
-    std::optional<c10::ScalarType> out_dtype,
-    bool use_fast_accum,
-    Tensor& out);
-} // namespace at::native

aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp

Lines changed: 1 addition & 21 deletions
@@ -57,10 +57,6 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
       return ideep::tensor::data_type::bf16;
     case ScalarType::Half:
       return ideep::tensor::data_type::f16;
-    case ScalarType::Float8_e4m3fn:
-      return ideep::tensor::data_type::f8_e4m3;
-    case ScalarType::Float8_e5m2:
-      return ideep::tensor::data_type::f8_e5m2;
     default:
       TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
   }
@@ -165,24 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
            const_cast<void*>(tensor.const_data_ptr()) :
            tensor.data_ptr()};
   }
-  else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
-    return {{tensor.sizes().vec(),
-        ideep::tensor::data_type::f8_e4m3,
-        tensor.strides().vec()},
-        from_const_data_ptr ?
-            const_cast<void*>(tensor.const_data_ptr()) :
-            tensor.data_ptr()};
-  }
-  else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
-    return {{tensor.sizes().vec(),
-        ideep::tensor::data_type::f8_e5m2,
-        tensor.strides().vec()},
-        from_const_data_ptr ?
-            const_cast<void*>(tensor.const_data_ptr()) :
-            tensor.data_ptr()};
-  }
   else {
-    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
+    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
   }
 }
 

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 2 deletions
@@ -7066,13 +7066,11 @@
 - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
   variants: function
   dispatch:
-    CPU: _scaled_mm_cpu
     CUDA: _scaled_mm_cuda
 
 - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CPU: _scaled_mm_out_cpu
     CUDA: _scaled_mm_out_cuda
 
 # NOTE [ Sparse: autograd and API ]
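
With the CPU entries removed from both dispatch lists, calling the operator on CPU tensors is expected to fail at dispatch time again rather than reach a kernel. A small illustrative check (the exact error type and message may differ by build):

import torch

a = torch.randn(8, 16).to(torch.float8_e4m3fn)
b = torch.randn(16, 4).to(torch.float8_e4m3fn)
try:
    torch._scaled_mm(a, b, torch.tensor(1.0), torch.tensor(1.0))
except (RuntimeError, NotImplementedError) as err:
    # Expected after this revert: no CPU kernel is registered for _scaled_mm.
    print(err)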
