forked from IntelPython/sharded-array-for-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathIEWBinOp.cpp
123 lines (109 loc) · 4.27 KB
/
IEWBinOp.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// SPDX-License-Identifier: BSD-3-Clause
/*
Inplace elementwise binary ops.
*/
#include "sharpy/IEWBinOp.hpp"
#include "sharpy/Creator.hpp"
#include "sharpy/Deferred.hpp"
#include "sharpy/Factory.hpp"
#include "sharpy/NDArray.hpp"
#include "sharpy/Registry.hpp"
#include "sharpy/TypeDispatch.hpp"
#include "sharpy/jit/mlir.hpp"
#include <imex/Dialect/Dist/IR/DistOps.h>
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <mlir/Dialect/Shape/IR/Shape.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>

#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>
namespace SHARPY {
// convert id of our binop to id of imex::ndarray binop
static ::imex::ndarray::EWBinOpId sharpy2mlir(const IEWBinOpId bop) {
switch (bop) {
case __IADD__:
return ::imex::ndarray::ADD;
case __IAND__:
return ::imex::ndarray::BITWISE_AND;
case __IFLOORDIV__:
return ::imex::ndarray::FLOOR_DIVIDE;
case __ILSHIFT__:
return ::imex::ndarray::BITWISE_LEFT_SHIFT;
case __IMOD__:
return ::imex::ndarray::MODULO;
case __IMUL__:
return ::imex::ndarray::MULTIPLY;
case __IOR__:
return ::imex::ndarray::BITWISE_OR;
case __IPOW__:
return ::imex::ndarray::POWER;
case __IRSHIFT__:
return ::imex::ndarray::BITWISE_RIGHT_SHIFT;
case __ISUB__:
return ::imex::ndarray::SUBTRACT;
case __ITRUEDIV__:
return ::imex::ndarray::TRUE_DIVIDE;
case __IXOR__:
return ::imex::ndarray::BITWISE_XOR;
default:
throw std::invalid_argument(
"Unknown/invalid inplace elementwise binary operation");
}
}
struct DeferredIEWBinOp : public Deferred {
id_type _a;
id_type _b;
IEWBinOpId _op;
DeferredIEWBinOp() = default;
DeferredIEWBinOp(IEWBinOpId op, const array_i::future_type &a,
const array_i::future_type &b)
: Deferred(a.dtype(), a.shape(), a.device(), a.team()), _a(a.guid()),
_b(b.guid()), _op(op) {}
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
jit::DepManager &dm) override {
// FIXME the type of the result is based on a only
auto av = dm.getDependent(builder, Registry::get(_a));
auto bv = dm.getDependent(builder, Registry::get(_b));
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
auto binop = builder.create<::imex::ndarray::EWBinOp>(
loc, outTyp, builder.getI32IntegerAttr(sharpy2mlir(_op)), av, bv);
// insertsliceop has no return value, so we just create the op...
auto zero = ::imex::createIndex(loc, builder, 0);
auto one = ::imex::createIndex(loc, builder, 1);
auto dyn = ::imex::createIndex(loc, builder, ::mlir::ShapedType::kDynamic);
::mlir::SmallVector<::mlir::Value> offs(rank(), zero);
::mlir::SmallVector<::mlir::Value> szs(rank(), dyn);
::mlir::SmallVector<::mlir::Value> strds(rank(), one);
(void)builder.create<::imex::ndarray::InsertSliceOp>(loc, av, binop, offs,
szs, strds);
// ... and use av as to later create the ndarray
dm.addVal(this->guid(), av,
[this](uint64_t rank, void *l_allocated, void *l_aligned,
intptr_t l_offset, const intptr_t *l_sizes,
const intptr_t *l_strides, void *o_allocated,
void *o_aligned, intptr_t o_offset,
const intptr_t *o_sizes, const intptr_t *o_strides,
void *r_allocated, void *r_aligned, intptr_t r_offset,
const intptr_t *r_sizes, const intptr_t *r_strides,
std::vector<int64_t> &&loffs) {
this->set_value(Registry::get(this->_a).get());
});
return false;
}
FactoryId factory() const override { return F_IEWBINOP; }
template <typename S> void serialize(S &ser) {
ser.template value<sizeof(_a)>(_a);
ser.template value<sizeof(_b)>(_b);
ser.template value<sizeof(_op)>(_op);
}
};
FutureArray *IEWBinOp::op(IEWBinOpId op, FutureArray &a, const py::object &b) {
auto bb =
Creator::mk_future(b, a.get().device(), a.get().team(), a.get().dtype());
auto res =
new FutureArray(defer<DeferredIEWBinOp>(op, a.get(), bb.first->get()));
if (bb.second)
delete bb.first;
return res;
}
FACTORY_INIT(DeferredIEWBinOp, F_IEWBINOP);
} // namespace SHARPY