op_copy.cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;

// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
// TODO: We actually shouldn't see this op with the proper functionalization,
// and this op needs to be deleted
Tensor& copy_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& src,
    bool non_blocking,
    Tensor& out) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "copy.out";

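  // Dispatch on the input dtype and perform an element-wise copy: the lambda
  // ignores the value coming from `in` and returns the corresponding value
  // from `src`, which is broadcast to `in`'s shape and written into `out`.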
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<
        CTYPE,
        op_name,
        utils::SupportedTensorDtypes::REALHBBF16>(
        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        src,
        utils::SupportedTensorDtypes::REALHBBF16,
        out);
  });

  return out;
}

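// copy_ is the in-place counterpart of copy.out: `src` is broadcast to `in`'s
// shape and the result overwrites `in`'s data.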
Tensor& copy_(
    KernelRuntimeContext& ctx,
    Tensor& in,
    const Tensor& src,
    bool non_blocking) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "copy_";

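  // Same element-wise copy as in copy_out, except `in` serves as both the
  // first operand and the destination, so `src` overwrites `in` in place.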
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<
        CTYPE,
        op_name,
        utils::SupportedTensorDtypes::REALHBBF16>(
        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        src,
        utils::SupportedTensorDtypes::REALHBBF16,
        in);
  });

  return in;
}

} // namespace native
} // namespace executor
} // namespace torch
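
// Usage sketch (illustrative, not part of the kernel): one way a test could
// exercise copy_out directly. TensorFactory and the default-constructed
// KernelRuntimeContext below are assumptions based on ExecuTorch's testing
// utilities; adapt the include and setup to your environment.
//
//   #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
//
//   using executorch::aten::ScalarType;
//   using torch::executor::testing::TensorFactory;
//
//   TensorFactory<ScalarType::Float> tf;
//   Tensor in = tf.zeros({2, 2});
//   Tensor src = tf.make({2}, {1.0f, 2.0f});  // broadcastable to {2, 2}
//   Tensor out = tf.zeros({2, 2});
//
//   torch::executor::KernelRuntimeContext ctx;
//   torch::executor::native::copy_out(ctx, in, src, /*non_blocking=*/false, out);
//   // out now holds {1, 2, 1, 2}: src was broadcast across both rows of in.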