4 | 4 | #include <ATen/core/Tensor.h> |
5 | 5 | #include <torch/library.h> |
6 | 6 | #include <ATen/native/mkldnn/Linear.h> |
7 | | -#include <ATen/native/Resize.h> |
8 | 7 |
9 | 8 | #ifndef AT_PER_OPERATOR_HEADERS |
10 | 9 | #include <ATen/Functions.h> |
@@ -47,20 +46,9 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward( |
47 | 46 | TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support"); |
48 | 47 | } |
49 | 48 |
50 | | -Tensor& |
51 | | -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, |
52 | | - const Tensor& scale_a, |
53 | | - const Tensor& scale_b, |
54 | | - const std::optional<at::Tensor>& bias, |
55 | | - const std::optional<at::Tensor>& scale_result, |
56 | | - std::optional<c10::ScalarType> out_dtype, |
57 | | - bool use_fast_accum, |
58 | | - Tensor& out) { |
59 | | - TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support"); |
60 | | -} |
61 | | - |
62 | 49 | } // namespace at::native |
63 | 50 |
| 51 | + |
64 | 52 | #else // AT_MKLDNN_ENABLED |
65 | 53 |
66 | 54 | #include <ATen/native/mkldnn/MKLDNNCommon.h> |
@@ -459,119 +447,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { |
459 | 447 | TORCH_FN(mkldnn_linear_pointwise_binary)); |
460 | 448 | } |
461 | 449 |
462 | | -Tensor& |
463 | | -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, |
464 | | - const Tensor& scale_a, |
465 | | - const Tensor& scale_b, |
466 | | - const std::optional<at::Tensor>& bias, |
467 | | - const std::optional<at::Tensor>& scale_result, |
468 | | - std::optional<c10::ScalarType> out_dtype, |
469 | | - bool use_fast_accum, |
470 | | - Tensor& out) { |
471 | | - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); |
472 | | - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); |
473 | | - TORCH_CHECK( |
474 | | - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", |
475 | | - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); |
476 | | - |
477 | | - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); |
478 | | - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], |
479 | | - " but got ", bias->numel()); |
480 | | - |
481 | | - // Check types |
482 | | - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); |
483 | | - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); |
484 | | - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); |
485 | | - // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6. |
486 | | - TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type"); |
487 | | - |
488 | | - // Validation checks have passed lets resize the output to actual size |
489 | | - auto mat1_c = mat1.contiguous(); |
490 | | - auto mat2_c = mat2.contiguous(); |
491 | | - IntArrayRef mat1_sizes = mat1_c.sizes(); |
492 | | - IntArrayRef mat2_sizes = mat2_c.sizes(); |
493 | | - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); |
494 | | - |
495 | | - float input_scale = scale_a.item<float>(); |
496 | | - float weight_scale = scale_b.item<float>(); |
497 | | - auto src = at::native::itensor_view_from_dense(mat1_c); |
498 | | - auto weight_t = at::native::itensor_view_from_dense(mat2_c); |
499 | | - bool with_bias = bias.has_value(); |
500 | | - int64_t K = mat1_sizes[1], M = mat1_sizes[0], |
501 | | - N = mat2_sizes[1]; |
502 | | - |
503 | | - std::vector<int64_t> src_dims = {M, K}; |
504 | | - std::vector<int64_t> weight_dims = {K, N}; |
505 | | - std::vector<int64_t> dst_dims = {M, N}; |
506 | | - |
507 | | - ideep::tensor dst = at::native::itensor_view_from_dense(out); |
508 | | - auto src_desc = ideep::tensor::desc( |
509 | | - src_dims, |
510 | | - get_mkldnn_dtype(mat1.scalar_type()), |
511 | | - ideep::format_tag::any); |
512 | | - auto weights_desc = ideep::tensor::desc( |
513 | | - weight_dims, |
514 | | - get_mkldnn_dtype(mat2.scalar_type()), |
515 | | - ideep::format_tag::any); |
516 | | - auto dst_desc = ideep::tensor::desc( |
517 | | - dst_dims, |
518 | | - get_mkldnn_dtype(out.scalar_type()), |
519 | | - ideep::format_tag::any); |
520 | | - ideep::tensor onednn_bias; |
521 | | - if (with_bias) { |
522 | | - auto bias_value = bias.value(); |
523 | | - if (bias_value.dim() == 1) { |
524 | | - auto b_reshape = bias_value.reshape({1, bias_value.size(0)}); |
525 | | - onednn_bias = at::native::itensor_view_from_dense(b_reshape); |
526 | | - } else { |
527 | | - onednn_bias = at::native::itensor_view_from_dense(bias_value); |
528 | | - } |
529 | | - } |
530 | | - auto bias_desc = ideep::tensor::desc(); |
531 | | - if (with_bias) { |
532 | | - bias_desc = ideep::tensor::desc(onednn_bias.get_dims(), |
533 | | - get_mkldnn_dtype(bias.value().scalar_type()), |
534 | | - ideep::format_tag::any); |
535 | | - } |
536 | | - auto op_attr = ideep::attr_t(); |
537 | | - if (input_scale != 1.0f) { |
538 | | - op_attr.set_scales_mask(DNNL_ARG_SRC, 0); |
539 | | - } |
540 | | - if (weight_scale != 1.0f) { |
541 | | - op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); |
542 | | - } |
543 | | - |
544 | | - op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); |
545 | | - auto engine = ideep::engine::cpu_engine(); |
546 | | - dnnl::matmul::primitive_desc primitive_desc = with_bias |
547 | | - ? dnnl::matmul::primitive_desc( |
548 | | - engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) |
549 | | - : dnnl::matmul::primitive_desc( |
550 | | - engine, src_desc, weights_desc, dst_desc, op_attr); |
551 | | - auto primitive = dnnl::matmul(primitive_desc); |
552 | | - |
553 | | - // Prepare args and execute primitive |
554 | | - ideep::tensor scratchpad(primitive_desc.scratchpad_desc()); |
555 | | - ideep::exec_args args; |
556 | | - args.insert({DNNL_ARG_SRC, src}); |
557 | | - args.insert({DNNL_ARG_WEIGHTS, weight_t}); |
558 | | - args.insert({DNNL_ARG_DST, dst}); |
559 | | - args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); |
560 | | - if (with_bias) { |
561 | | - args.insert({DNNL_ARG_BIAS, onednn_bias}); |
562 | | - } |
563 | | - ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale)); |
564 | | - ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale)); |
565 | | - |
566 | | - if (input_scale != 1.0f) { |
567 | | - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t}); |
568 | | - } |
569 | | - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t}); |
570 | | - |
571 | | - primitive.execute(ideep::stream::default_stream(), args); |
572 | | - return out; |
573 | | -} |
574 | | - |
575 | 450 | } // namespace at |
576 | 451 |
577 | 452 | #endif // AT_MKLDNN_ENABLED |
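For reference, the kernel removed above applied per-tensor dequantization scales to the two FP8 operands via oneDNN scale attributes (`DNNL_ARG_ATTR_SCALES`), so it effectively computed `out = (scale_a * mat1) @ (scale_b * mat2) [+ bias]`, cast to `out_dtype` (`scale_result` and `use_fast_accum` were unused). Below is a minimal sketch of those semantics in plain ATen ops; the helper name is hypothetical, and this is an illustration of the math only, not the dispatched oneDNN path.

```cpp
#include <ATen/ATen.h>
#include <optional>

// Hypothetical reference helper (not part of ATen): reproduces the
// per-tensor-scaled matmul semantics of the removed kernel with plain
// ATen ops instead of a oneDNN matmul primitive.
at::Tensor scaled_mm_reference(
    const at::Tensor& mat1,                 // [M, K], Float8
    const at::Tensor& mat2,                 // [K, N], Float8
    const at::Tensor& scale_a,              // scalar dequantization scale for mat1
    const at::Tensor& scale_b,              // scalar dequantization scale for mat2
    const std::optional<at::Tensor>& bias,  // optional bias of length N
    at::ScalarType out_dtype) {
  // Dequantize each operand with its per-tensor scale, then matmul in float.
  auto out = at::matmul(mat1.to(at::kFloat) * scale_a,
                        mat2.to(at::kFloat) * scale_b);
  if (bias.has_value()) {
    out = out + bias->to(at::kFloat);
  }
  return out.to(out_dtype);
}
```

Such a sketch is only useful for sanity-checking numerics against whichever backend now serves `_scaled_mm` on CPU; it makes no attempt to match the removed code's performance characteristics.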