Skip to content

Commit d55c074

Browse files
committed
feat(kernel): 实现 slice info 整合
Signed-off-by: YdrMaster <[email protected]>
1 parent 58adc3f commit d55c074

File tree

5 files changed

+88
-2
lines changed

5 files changed

+88
-2
lines changed

src/00common/include/common.h

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace refactor {
1717
using ddim_t = int16_t;
1818
/// @brief 用于表示形状的数值。
1919
using dim_t = uint32_t;
20+
/// @brief 用于表示带符号的形状的数值。
21+
using sdim_t = int32_t;
2022
/// @brief 用于表示对象的数量。
2123
using count_t = uint32_t;
2224

src/04kernel/include/kernel/attributes/slice_info.h

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,27 @@
44
#include "../tensor.h"
55

66
namespace refactor::kernel {
7+
namespace slice {
8+
struct Dim {
9+
int64_t start, step, length;
10+
};
11+
}// namespace slice
12+
13+
using Dimensions = std::vector<slice::Dim>;
714

815
/// @brief 优化用于计算的 Slice 描述。
916
struct SliceInfo {
1017
struct Dim {
11-
int64_t start, step, length;
18+
dim_t countStride, sizeStart;
19+
sdim_t sizeStride;
20+
21+
bool operator==(Dim const &) const noexcept;
22+
bool operator!=(Dim const &) const noexcept;
1223
};
24+
std::vector<Dim> dims;
25+
dim_t blockSize;
26+
27+
SliceInfo(Dimensions const &, Tensor const &) noexcept;
1328
};
1429

1530
}// namespace refactor::kernel

src/04kernel/include/kernel/collectors/slice.h

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "../target.h"
77

88
namespace refactor::kernel {
9-
using Dimensions = std::vector<SliceInfo::Dim>;
109

1110
struct SliceCollector final : public InfoCollector {
1211
Target target;

src/04kernel/src/attributes/slice_info.cc

+42
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,46 @@
22

33
namespace refactor::kernel {
44

5+
bool SliceInfo::Dim::operator==(Dim const &rhs) const noexcept {
6+
return countStride == rhs.countStride &&
7+
sizeStart == rhs.sizeStart &&
8+
sizeStride == rhs.sizeStride;
9+
}
10+
bool SliceInfo::Dim::operator!=(Dim const &rhs) const noexcept {
11+
return !operator==(rhs);
12+
}
13+
14+
SliceInfo::SliceInfo(Dimensions const &dims_, Tensor const &input) noexcept
15+
: blockSize(input.dataType.size()), dims(1) {
16+
ASSERT(dims_.size() == input.rank(), "Unreachable");
17+
18+
auto continuous = true;
19+
auto stride = blockSize;
20+
dims[0] = {1, 0, static_cast<sdim_t>(stride)};
21+
for (auto i : range0_(input.rank()).rev()) {
22+
auto l = input.shape[i];
23+
auto const &d = dims_[i];
24+
if (continuous && d.step == 1) {
25+
auto &it = dims.back();
26+
it.countStride *= d.length;
27+
it.sizeStart = d.start * stride;
28+
it.sizeStride *= l;
29+
} else {
30+
dims.push_back(Dim{
31+
static_cast<dim_t>(dims.back().countStride * d.length),
32+
static_cast<dim_t>(d.start * stride),
33+
static_cast<sdim_t>(d.step * stride),
34+
});
35+
}
36+
continuous = d.length == l;
37+
stride *= l;
38+
}
39+
auto blockCount = dims[0].countStride;
40+
blockSize *= blockCount;
41+
for (auto &d : dims) {
42+
d.countStride /= blockCount;
43+
}
44+
std::reverse(dims.begin(), dims.end());
45+
}
46+
547
}// namespace refactor::kernel
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "kernel/attributes/slice_info.h"
2+
#include <gtest/gtest.h>
3+
4+
using namespace refactor;
5+
using namespace kernel;
6+
7+
TEST(kernel, SliceInfo) {
8+
auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3});
9+
Dimensions dims{
10+
{5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360}
11+
{2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90}
12+
{1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30}
13+
{0, 1, 1}, // 1 -> {0}
14+
{0, 1, 2}, // 2 -> {0, 1}
15+
{0, 1, 3}, // 3 -> {0, 1, 2}
16+
};
17+
SliceInfo info(dims, *input);
18+
EXPECT_EQ(info.blockSize, 72);
19+
EXPECT_EQ(info.dims,
20+
// clang-format off
21+
(decltype(info.dims){
22+
{108 / 18, 900 * 4, -360 * 4},
23+
{ 36 / 18, 60 * 4, 90 * 4},
24+
{ 18 / 18, 6 * 4, 30 * 4},
25+
})
26+
// clang-format on
27+
);
28+
}

0 commit comments

Comments
 (0)