Skip to content

Commit 35dc6c8

Browse files
authored
Merge pull request #84 from InfiniTensor/add_max_min_kernel
add max/min kernel
2 parents 6630866 + e237349 commit 35dc6c8

File tree

13 files changed

+571
-35
lines changed

13 files changed

+571
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef KERNEL_SELECT_H
2+
#define KERNEL_SELECT_H
3+
4+
#include "../collector.h"
5+
6+
namespace refactor::kernel {
7+
8+
/// Identifies which element-wise selection a select kernel performs.
enum class SelectType {
    Max,///< keep the larger of the operands
    Min,///< keep the smaller of the operands
};

/// Returns the textual name of `type` ("Max" or "Min").
std::string_view opName(SelectType type);
14+
15+
struct SelectCollector final : public InfoCollector {
16+
SelectType selectType;
17+
18+
SelectCollector(decltype(_target), SelectType) noexcept;
19+
20+
std::vector<KernelBox>
21+
filter(TensorRefs inputs, TensorRefs outputs) const final;
22+
};
23+
24+
}// namespace refactor::kernel
25+
26+
#endif// KERNEL_SELECT_H

src/04kernel/include/kernel/selector.h

-16
This file was deleted.

src/04kernel/src/attributes/broadcaster.cc

+11-7
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,17 @@ namespace refactor::kernel {
8686
}()) {}
8787

8888
// Computes, for the output element with linear index `k`, one offset per input
// and stores them in ans[0 .. inputsCount).
//
// `strides` is read as consecutive rows of (inputsCount + 1) values: the first
// inputsCount values of a row are per-input multipliers and the last value is
// the divisor applied to the remaining index.
void Broadcaster::locate(dim_t k, dim_t ans[]) const noexcept {
89-
long rem = k;
90-
std::fill_n(ans, inputsCount, 0);
91-
for (auto i : range0_(strides.size() / (inputsCount + 1))) {
92-
auto dim = strides.data() + (inputsCount + 1) * i;
93-
auto div = std::div(rem, dim[inputsCount]);
94-
for (auto j : range0_(inputsCount)) { ans[j] += dim[j] * div.quot; }
95-
rem = div.rem;
89+
// Fast path: when no broadcasting is needed every input offset equals k.
if (!needBroadcast()) {
90+
std::fill_n(ans, inputsCount, k);
91+
} else {
92+
long rem = k;
93+
// Offsets are accumulated row by row, so start from zero.
std::fill_n(ans, inputsCount, 0);
94+
for (auto i : range0_(strides.size() / (inputsCount + 1))) {
95+
auto dim = strides.data() + (inputsCount + 1) * i;
96+
// quot scales this row's multipliers; rem is carried to the next row.
auto div = std::div(rem, dim[inputsCount]);
97+
for (auto j : range0_(inputsCount)) { ans[j] += dim[j] * div.quot; }
98+
rem = div.rem;
99+
}
96100
}
97101
}
98102

src/04kernel/src/collectors/select.cc

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "kernel/collectors/select.h"
2+
#include "../kernels/select/cpu_kernel.hh"
3+
#include "../kernels/select/cuda_kernel.hh"
4+
5+
namespace refactor::kernel {
6+
7+
#define REGISTER(T) \
8+
if (auto ptr = T::build(selectType, inputs); ptr) { \
9+
ans.emplace_back(std::move(ptr)); \
10+
}
11+
12+
#define CASE(OP) \
13+
case SelectType::OP: \
14+
return #OP
15+
16+
std::string_view opName(SelectType type) {
17+
switch (type) {
18+
CASE(Max);
19+
CASE(Min);
20+
default:
21+
UNREACHABLE();
22+
}
23+
}
24+
25+
SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept
26+
: InfoCollector(target), selectType(type) {}
27+
28+
std::vector<KernelBox>
29+
SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
30+
std::vector<KernelBox> ans;
31+
switch (_target) {
32+
case decltype(_target)::Cpu:
33+
REGISTER(SelectCpu)
34+
break;
35+
case decltype(_target)::Nvidia:
36+
REGISTER(SelectCuda)
37+
break;
38+
default:
39+
UNREACHABLEX(void, "Unknown target");
40+
}
41+
return ans;
42+
}
43+
44+
}// namespace refactor::kernel

src/04kernel/src/kernels/concat/cuda_kernel.cc

-7
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,6 @@ extern "C" __global__ void kernel(
8181
}
8282
auto segments = ss.str();
8383

84-
ss.str("");
85-
for (auto i : range0_(inputCount)) {
86-
ss << std::endl
87-
<< " reinterpret_cast<char const *>(inputs[" << i << "]), ";
88-
}
89-
auto castInputs = ss.str();
90-
9184
ss.str("");
9285
ss << "Concat_" << info.blockCount << ',' << unit;
9386
for (auto seg : info.segments) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include "cpu_kernel.hh"
2+
#include <execution>
3+
4+
namespace refactor::kernel {
5+
using K = SelectCpu;
6+
using DT = DataType;
7+
8+
/// Stores the element type, selection mode, broadcast description and operand count.
K::SelectCpu(
    decltype(dataType) elementType,
    decltype(selectType) mode,
    decltype(broadcaster) broadcastInfo,
    decltype(inputsNum) operandCount) noexcept
    : dataType(elementType),
      selectType(mode),
      broadcaster(broadcastInfo),
      inputsNum(operandCount) {}
17+
18+
/// Builds a CPU select kernel for `inputs_`, or returns nullptr when unsupported.
///
/// The element type is taken from the first input only; the kernel is created
/// when that type satisfies `isCpuNumberic()`.
///
/// @param selectType_  selection (Max or Min) the kernel will perform.
/// @param inputs_      operand tensors; an empty list yields nullptr.
/// @return the new kernel, or nullptr if there are no inputs or the type is unsupported.
auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
    // Reading inputs_[0] on an empty list would be undefined behaviour in a
    // noexcept function, so reject it the same way an unsupported type is rejected.
    if (inputs_.size() == 0) { return nullptr; }

    auto const &x = inputs_[0].get();
    return x.dataType.isCpuNumberic()
               ? std::make_unique<K>(x.dataType, selectType_, Broadcaster(inputs_), inputs_.size())
               : nullptr;
}
24+
/// Returns an identifier for this kernel type: the address of a function-local static.
size_t K::typeId() noexcept {
    static uint8_t ID = 1;
    auto const address = reinterpret_cast<size_t>(&ID);
    return address;
}
28+
29+
auto K::kernelTypeId() const noexcept -> size_t {
30+
return typeId();
31+
}
32+
/// Human-readable description of this kernel.
std::string_view K::description() const noexcept {
    return "Performing select operation on generic cpu";
}
35+
36+
template<class T>
37+
auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace {
38+
using namespace runtime;
39+
40+
T(*op)
41+
(T const a, T const b);
42+
switch (selectType) {
43+
case SelectType::Max:
44+
op = [](T const a, T const b) { return std::max(a, b); };
45+
break;
46+
case SelectType::Min:
47+
op = [](T const a, T const b) { return std::min(a, b); };
48+
break;
49+
default:
50+
UNREACHABLE();
51+
}
52+
53+
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
54+
auto output = reinterpret_cast<T *>(outputs[0]);
55+
for (auto i : range0_(broadcaster.outputsCount)) {
56+
std::vector<dim_t> ans(broadcaster.inputsCount);
57+
broadcaster.locate(i, ans.data());
58+
for (auto inputIdx : range0_(inputsNum)) {
59+
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
60+
if (inputIdx == 0) {
61+
output[i] = input[ans[inputIdx]];
62+
} else {
63+
output[i] = op(output[i], input[ans[inputIdx]]);
64+
}
65+
}
66+
}
67+
};
68+
}
69+
70+
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
71+
#define CASE(DT) \
72+
case DataType::DT: \
73+
return lowerTyped<primitive<DataType::DT>::type>(selectType, broadcaster, inputsNum)
74+
75+
switch (dataType) {
76+
CASE(F32);
77+
CASE(U8);
78+
CASE(I8);
79+
CASE(U16);
80+
CASE(I16);
81+
CASE(I32);
82+
CASE(I64);
83+
CASE(F64);
84+
CASE(U32);
85+
CASE(U64);
86+
default:
87+
UNREACHABLE();
88+
}
89+
}
90+
91+
}// namespace refactor::kernel
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef KERNEL_SELECT_CPU_KERNEL_HH
2+
#define KERNEL_SELECT_CPU_KERNEL_HH
3+
4+
#include "kernel/attributes/broadcaster.h"
5+
#include "kernel/collectors/select.h"
6+
#include "kernel/kernel.h"
7+
#include "kernel/tensor.h"
8+
9+
namespace refactor::kernel {
10+
11+
struct SelectCpu final : public Kernel {
12+
DataType dataType;
13+
SelectType selectType;
14+
Broadcaster broadcaster;
15+
size_t inputsNum;
16+
17+
SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;
18+
19+
static KernelBox build(SelectType, TensorRefs) noexcept;
20+
static size_t typeId() noexcept;
21+
22+
size_t kernelTypeId() const noexcept final;
23+
std::string_view description() const noexcept final;
24+
RoutineWorkspace lower(Resources &) const noexcept final;
25+
};
26+
27+
}// namespace refactor::kernel
28+
29+
#endif// KERNEL_SELECT_CPU_KERNEL_HH

0 commit comments

Comments
 (0)