Skip to content

Commit 41abc9d

Browse files
committed
wip
1 parent 361dc2b commit 41abc9d

File tree

9 files changed

+103
-3
lines changed

9 files changed

+103
-3
lines changed

Diff for: CMakeLists.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,10 @@ include_directories(
149149
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
150150
${MPI_INCLUDE_PATH}
151151
${pybind11_INCLUDE_DIRS}
152+
${LLVM_INCLUDE_DIRS}
152153
${MLIR_INCLUDE_DIRS}
153-
${IMEX_INCLUDE_DIRS})
154+
${IMEX_INCLUDE_DIRS}
155+
"/export/users/yzhao/work/sharpy_ws/builds/debug/tools/Imex/include")
154156

155157
if (CMAKE_SYSTEM_NAME STREQUAL Linux)
156158
target_link_options(_sharpy PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/export.txt")

Diff for: setup.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import multiprocessing
12
import os
23
import pathlib
34

@@ -44,7 +45,10 @@ def build_cmake(self, ext):
4445
os.chdir(str(build_temp))
4546
self.spawn(["cmake", str(cwd)] + cmake_args)
4647
if not self.dry_run:
47-
self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
48+
self.spawn(
49+
["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
50+
+ build_args
51+
)
4852
# Troubleshooting: if fail on line above then delete all possible
4953
# temporary CMake files including "CMakeCache.txt" in top level dir.
5054
os.chdir(str(cwd))

Diff for: sharpy/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def _validate_device(device):
130130
exec(
131131
f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
132132
)
133+
elif func == "permute_dims":
134+
exec(
135+
f"{func} = lambda this, axes: ndarray(_csp.ManipOp.permute_dims(this._t, axes))"
136+
)
133137

134138
for func in api.api_categories["ReduceOp"]:
135139
FUNC = func.upper()

Diff for: sharpy/array_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
"roll", # (x, /, shift, *, axis=None)
180180
"squeeze", # (x, /, axis)
181181
"stack", # (arrays, /, *, axis=0)
182+
"permute_dims", # (x: array, /, axes: Tuple[int, ...]) → array
182183
],
183184
"LinAlgOp": [
184185
"matmul", # (x1, x2, /)

Diff for: src/ManipOp.cpp

+76
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,57 @@ struct DeferredToDevice : public Deferred {
205205
}
206206
};
207207

208+
struct DeferredPermuteDims : public Deferred {
209+
id_type _array;
210+
shape_type _axes;
211+
212+
DeferredPermuteDims() = default;
213+
DeferredPermuteDims(const array_i::future_type &array,
214+
const shape_type &shape, const shape_type &axes)
215+
: Deferred(array.dtype(), shape, array.device(), array.team()),
216+
_array(array.guid()), _axes(axes) {}
217+
218+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
219+
jit::DepManager &dm) override {
220+
auto arrayValue = dm.getDependent(builder, Registry::get(_array));
221+
222+
auto axesAttr = builder.getI64ArrayAttr(_axes);
223+
224+
auto aTyp =
225+
::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType());
226+
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());
227+
228+
auto op = builder.create<::imex::ndarray::PermuteDimsOp>(
229+
loc, outTyp, arrayValue, axesAttr);
230+
231+
dm.addVal(
232+
this->guid(), op,
233+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
234+
intptr_t l_offset, const intptr_t *l_sizes,
235+
const intptr_t *l_strides, void *o_allocated, void *o_aligned,
236+
intptr_t o_offset, const intptr_t *o_sizes,
237+
const intptr_t *o_strides, void *r_allocated, void *r_aligned,
238+
intptr_t r_offset, const intptr_t *r_sizes,
239+
const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
240+
auto t = mk_tnsr(this->guid(), _dtype, this->shape(), this->device(),
241+
this->team(), l_allocated, l_aligned, l_offset,
242+
l_sizes, l_strides, o_allocated, o_aligned, o_offset,
243+
o_sizes, o_strides, r_allocated, r_aligned, r_offset,
244+
r_sizes, r_strides, std::move(loffs));
245+
this->set_value(std::move(t));
246+
});
247+
248+
return false;
249+
}
250+
251+
FactoryId factory() const override { return F_PERMUTEDIMS; }
252+
253+
template <typename S> void serialize(S &ser) {
254+
ser.template value<sizeof(_array)>(_array);
255+
// ser.template value<sizeof(_axes)>(_axes);
256+
}
257+
};
258+
208259
FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
209260
const py::object &copy) {
210261
auto doCopy = copy.is_none()
@@ -229,7 +280,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a,
229280
return new FutureArray(defer<DeferredToDevice>(a.get(), device));
230281
}
231282

283+
FutureArray *ManipOp::permute_dims(const FutureArray &array,
284+
const shape_type &axes) {
285+
auto shape = array.get().shape();
286+
287+
// verifyPermuteArray
288+
if (shape.size() != axes.size()) {
289+
throw std::invalid_argument("axes must have the same length as the shape");
290+
}
291+
for (auto i = 0ul; i < shape.size(); ++i) {
292+
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
293+
throw std::invalid_argument("axes must contain all dimensions");
294+
}
295+
}
296+
297+
auto permutedShape = shape_type(shape.size());
298+
for (auto i = 0ul; i < shape.size(); ++i) {
299+
permutedShape[i] = shape[axes[i]];
300+
}
301+
302+
return new FutureArray(
303+
defer<DeferredPermuteDims>(array.get(), permutedShape, axes));
304+
}
305+
232306
FACTORY_INIT(DeferredReshape, F_RESHAPE);
233307
FACTORY_INIT(DeferredAsType, F_ASTYPE);
234308
FACTORY_INIT(DeferredToDevice, F_TODEVICE);
309+
FACTORY_INIT(DeferredPermuteDims, F_PERMUTEDIMS);
310+
235311
} // namespace SHARPY

Diff for: src/_sharpy.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ PYBIND11_MODULE(_sharpy, m) {
196196
py::class_<IEWBinOp>(m, "IEWBinOp").def("op", &IEWBinOp::op);
197197
py::class_<EWBinOp>(m, "EWBinOp").def("op", &EWBinOp::op);
198198
py::class_<ReduceOp>(m, "ReduceOp").def("op", &ReduceOp::op);
199-
py::class_<ManipOp>(m, "ManipOp").def("reshape", &ManipOp::reshape);
199+
py::class_<ManipOp>(m, "ManipOp")
200+
.def("reshape", &ManipOp::reshape)
201+
.def("permute_dims", &ManipOp::permute_dims);
200202
py::class_<LinAlgOp>(m, "LinAlgOp").def("vecdot", &LinAlgOp::vecdot);
201203

202204
py::class_<FutureArray>(m, "SHARPYFuture")

Diff for: src/include/sharpy/CppTypes.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ enum FactoryId : int {
339339
F_REDUCEOP,
340340
F_REPLICATE,
341341
F_RESHAPE,
342+
F_PERMUTEDIMS,
342343
F_SERVICE,
343344
F_SETITEM,
344345
F_ASTYPE,

Diff for: src/include/sharpy/ManipOp.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@ struct ManipOp {
2020

2121
static FutureArray *to_device(const FutureArray &a,
2222
const std::string &device);
23+
24+
static FutureArray *permute_dims(const FutureArray &array,
25+
const shape_type &axes);
2326
};
2427
} // namespace SHARPY

Diff for: test/test_manip.py

+7
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,10 @@ def test_todevice_host2gpu(self):
8686
a = sp.arange(0, 8, 1, sp.int32)
8787
b = a.to_device(device="GPU")
8888
assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
89+
90+
def test_permute_dims(self):
91+
def doit(aapi, **kwargs):
92+
a = aapi.arange(0, 12 * 11, 1, aapi.int32, **kwargs)
93+
return aapi.permite_dims(a, [1, 0])
94+
95+
assert runAndCompare(doit)

0 commit comments

Comments
 (0)