-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathtest_util.h
74 lines (64 loc) · 1.96 KB
/
test_util.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <array>
#include <vector>
namespace intel {
namespace he {
// Generates a vector of `slots` entries of type T holding small, easily
// recognisable test values.
//
// - row_size == 0: entry i is set to i, i.e. {0, 1, ..., slots - 1}.
// - row_size  > 0: the vector is viewed as consecutive rows of `row_size`
//   entries. In each of the first `n_rows` rows, the first `n_slots` entries
//   are filled (row r, column i receives i + r * n_slots); every other entry
//   stays 0. If n_slots exceeds row_size, later rows overwrite the tail of
//   earlier ones.
//
// Positions that would land at or beyond `slots` are skipped, so a row layout
// larger than the vector never writes out of bounds.
template <typename T>
inline std::vector<T> generateVector(size_t slots, size_t row_size = 0,
                                     size_t n_rows = 2, size_t n_slots = 4) {
  std::vector<T> input(slots, static_cast<T>(0));
  if (row_size == 0) {
    for (size_t i = 0; i < slots; ++i) {
      input[i] = static_cast<T>(i);
    }
    return input;
  }
  for (size_t r = 0; r < n_rows; ++r) {
    for (size_t i = 0; i < n_slots; ++i) {
      const size_t index = i + r * row_size;
      // Without this guard, e.g. generateVector<T>(4, 4) wrote rows past the
      // end of the vector (undefined behavior).
      if (index < slots) {
        input[index] = static_cast<T>(i + r * n_slots);
      }
    }
  }
  return input;
}
// gtest helper: asserts that x and y have the same length and that each pair
// of corresponding elements differs by at most abs_error.
//
// abs_error defaults to T(0.001); for an integral T that converts to 0, so
// integer vectors are compared exactly.
//
// Note: ASSERT_* macros return from the function they appear in, so a failure
// here stops this helper only, not the calling test body.
template <typename T>
void checkEqual(const std::vector<T>& x, const std::vector<T>& y,
                T abs_error = T(0.001)) {
  ASSERT_EQ(x.size(), y.size());
  for (size_t i = 0; i < x.size(); ++i) {
    // Use the caller-supplied tolerance; it was previously ignored in favour
    // of a hard-coded 0.001.
    ASSERT_NEAR(x[i], y[i], abs_error);
  }
}
// gtest helper for nested vectors: asserts that x and y contain the same
// number of rows, then compares each row pair with the element-wise
// checkEqual overload using the same abs_error tolerance.
template <typename T>
void checkEqual(const std::vector<std::vector<T>>& x,
                const std::vector<std::vector<T>>& y, T abs_error = T(0.001)) {
  ASSERT_EQ(x.size(), y.size());
  // Sizes are known to match here, so both iterators reach their end together.
  auto lhs_row = x.cbegin();
  auto rhs_row = y.cbegin();
  for (; lhs_row != x.cend(); ++lhs_row, ++rhs_row) {
    checkEqual(*lhs_row, *rhs_row, abs_error);
  }
}
// Evaluates the polynomial
//   coeff[0] + coeff[1] * input + ... + coeff[n-1] * input^(n-1)
// with Horner's method, walking the coefficients from the highest power down.
// (Despite the name, this evaluates a polynomial; the name is kept so existing
// callers keep compiling.)
//
// CollectionT must provide rbegin()/rend() whose elements convert to double,
// with coefficients stored in ascending order of power. An empty collection is
// treated as the zero polynomial and yields 0.0.
template <class CollectionT>
double evaluatePolygon_HornerMethod(double input, const CollectionT& coeff) {
  auto it = coeff.rbegin();
  const auto end = coeff.rend();
  // Guard: dereferencing rbegin() of an empty range is undefined behavior.
  if (it == end) return 0.0;
  double retval = *it;
  for (++it; it != end; ++it) retval = retval * input + *it;
  return retval;
}
// Polynomial approximation of the sigmoid function of the given degree.
// Only the degree-3 specialization below is defined in this header.
template <unsigned int sigmoid_degree>
double approxSigmoid(double x);

// Degree-3 approximation:
//   f3(x) ~= 0.5 + 1.20096(x/8) - 0.81562(x/8)^3
// expanded into coefficients of x in ascending power order
// (1.20096 / 8 = 0.15012, 0.81562 / 8^3 = 0.001593008).
// The result is clamped to [0, 1]: it is forced to 0 for x < -5 or a negative
// polynomial value, and otherwise to 1 for x > 5 or a value above 1.
template <>
inline double approxSigmoid<3>(double x) {
  const std::array<double, 4> coefficients = {0.5, 0.15012, 0.0, -0.001593008};
  const double approx = evaluatePolygon_HornerMethod(x, coefficients);
  if (x < -5.0 || approx < 0.0) return 0.0;
  if (x > 5.0 || approx > 1.0) return 1.0;
  return approx;
}
} // namespace he
} // namespace intel