Skip to content

Commit 6fc7166

Browse files
Implementation of correlate
1 parent 88911fb commit 6fc7166

17 files changed

+1722
-131
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

+4-1
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,14 @@
2626

2727
set(python_module_name _statistics_impl)
2828
set(_module_src
29-
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3029
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
3335
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
36+
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
3437
)
3538

3639
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/statistics/dispatch_table.hpp

+98
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,32 @@ using DTypePair = std::pair<DType, DType>;
9797
using SupportedDTypeList = std::vector<DType>;
9898
using SupportedDTypeList2 = std::vector<DTypePair>;
9999

100+
template <typename FnT,
101+
typename SupportedTypes,
102+
template <typename>
103+
typename Func>
104+
struct TableBuilder
105+
{
106+
template <typename _FnT, typename T>
107+
struct impl
108+
{
109+
static constexpr bool is_defined = one_of_v<T, SupportedTypes>;
110+
111+
_FnT get()
112+
{
113+
if constexpr (is_defined) {
114+
return Func<T>::impl;
115+
}
116+
else {
117+
return nullptr;
118+
}
119+
}
120+
};
121+
122+
using type =
123+
dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
124+
};
125+
100126
template <typename FnT,
101127
typename SupportedTypes,
102128
template <typename, typename>
@@ -124,6 +150,78 @@ struct TableBuilder2
124150
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
125151
};
126152

153+
template <typename FnT>
154+
class DispatchTable
155+
{
156+
public:
157+
DispatchTable(std::string name) : name(name) {}
158+
159+
template <typename SupportedTypes, template <typename> typename Func>
160+
void populate_dispatch_table()
161+
{
162+
using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
163+
TBulder builder;
164+
165+
builder.populate_dispatch_vector(table);
166+
populate_supported_types();
167+
}
168+
169+
FnT get_unsafe(int _typenum) const
170+
{
171+
auto array_types = dpctl_td_ns::usm_ndarray_types();
172+
const int type_id = array_types.typenum_to_lookup_id(_typenum);
173+
174+
return table[type_id];
175+
}
176+
177+
FnT get(int _typenum) const
178+
{
179+
auto fn = get_unsafe(_typenum);
180+
181+
if (fn == nullptr) {
182+
auto array_types = dpctl_td_ns::usm_ndarray_types();
183+
const int _type_id = array_types.typenum_to_lookup_id(_typenum);
184+
185+
py::dtype _dtype = dtype_from_typenum(_type_id);
186+
auto _type_pos = std::find(supported_types.begin(),
187+
supported_types.end(), _dtype);
188+
if (_type_pos == supported_types.end()) {
189+
py::str types = py::str(py::cast(supported_types));
190+
py::str dtype = py::str(_dtype);
191+
192+
py::str err_msg =
193+
py::str("'" + name + "' has unsupported type '") + dtype +
194+
py::str("'."
195+
" Supported types are: ") +
196+
types;
197+
198+
throw py::value_error(static_cast<std::string>(err_msg));
199+
}
200+
}
201+
202+
return fn;
203+
}
204+
205+
const SupportedDTypeList &get_all_supported_types() const
206+
{
207+
return supported_types;
208+
}
209+
210+
private:
211+
void populate_supported_types()
212+
{
213+
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
214+
if (table[i] != nullptr) {
215+
supported_types.emplace_back(dtype_from_typenum(i));
216+
}
217+
}
218+
}
219+
220+
std::string name;
221+
SupportedDTypeList supported_types;
222+
Table<FnT> table;
223+
};
224+
127225
template <typename FnT>
128226
class DispatchTable2
129227
{

dpnp/backend/extensions/statistics/histogram_common.cpp

+40-100
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@
3838

3939
#include "histogram_common.hpp"
4040

41+
#include "validation_utils.hpp"
42+
4143
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4244
using dpctl::tensor::usm_ndarray;
4345
using dpctl_td_ns::typenum_t;
@@ -46,6 +48,15 @@ namespace statistics
4648
{
4749
using common::CeilDiv;
4850

51+
using validation::array_names;
52+
using validation::array_ptr;
53+
54+
using validation::check_max_dims;
55+
using validation::check_num_dims;
56+
using validation::check_size_at_least;
57+
using validation::common_checks;
58+
using validation::name_of;
59+
4960
namespace histogram
5061
{
5162

@@ -55,11 +66,9 @@ void validate(const usm_ndarray &sample,
5566
const usm_ndarray &histogram)
5667
{
5768
auto exec_q = sample.get_queue();
58-
using array_ptr = const usm_ndarray *;
5969

6070
std::vector<array_ptr> arrays{&sample, &histogram};
61-
std::unordered_map<array_ptr, std::string> names = {
62-
{arrays[0], "sample"}, {arrays[1], "histogram"}};
71+
array_names names = {{arrays[0], "sample"}, {arrays[1], "histogram"}};
6372

6473
array_ptr bins_ptr = nullptr;
6574

@@ -77,117 +86,48 @@ void validate(const usm_ndarray &sample,
7786
names.insert({weights_ptr, "weights"});
7887
}
7988

80-
auto get_name = [&](const array_ptr &arr) {
81-
auto name_it = names.find(arr);
82-
assert(name_it != names.end());
83-
84-
return "'" + name_it->second + "'";
85-
};
86-
87-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(histogram);
88-
89-
auto unequal_queue =
90-
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
91-
return arr->get_queue() != exec_q;
92-
});
93-
94-
if (unequal_queue != arrays.cend()) {
95-
throw py::value_error(
96-
get_name(*unequal_queue) +
97-
" parameter has incompatible queue with parameter " +
98-
get_name(&sample));
99-
}
100-
101-
auto non_contig_array =
102-
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
103-
return !arr->is_c_contiguous();
104-
});
89+
common_checks({&sample, bins.has_value() ? &bins.value() : nullptr,
90+
weights.has_value() ? &weights.value() : nullptr},
91+
{&histogram}, names);
10592

106-
if (non_contig_array != arrays.cend()) {
107-
throw py::value_error(get_name(*non_contig_array) +
108-
" parameter is not c-contiguos");
109-
}
93+
check_size_at_least(bins_ptr, 2, names);
11094

111-
auto check_overlaping = [&](const array_ptr &first,
112-
const array_ptr &second) {
113-
if (first == nullptr || second == nullptr) {
114-
return;
115-
}
116-
117-
const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
118-
119-
if (overlap(*first, *second)) {
120-
throw py::value_error(get_name(first) +
121-
" has overlapping memory segments with " +
122-
get_name(second));
123-
}
124-
};
125-
126-
check_overlaping(&sample, &histogram);
127-
check_overlaping(bins_ptr, &histogram);
128-
check_overlaping(weights_ptr, &histogram);
129-
130-
if (bins_ptr && bins_ptr->get_size() < 2) {
131-
throw py::value_error(get_name(bins_ptr) +
132-
" parameter must have at least 2 elements");
133-
}
134-
135-
if (histogram.get_size() < 1) {
136-
throw py::value_error(get_name(&histogram) +
137-
" parameter must have at least 1 element");
138-
}
139-
140-
if (histogram.get_ndim() != 1) {
141-
throw py::value_error(get_name(&histogram) +
142-
" parameter must be 1d. Actual " +
143-
std::to_string(histogram.get_ndim()) + "d");
144-
}
95+
check_size_at_least(&histogram, 1, names);
96+
check_num_dims(&histogram, 1, names);
14597

14698
if (weights_ptr) {
147-
if (weights_ptr->get_ndim() != 1) {
148-
throw py::value_error(
149-
get_name(weights_ptr) + " parameter must be 1d. Actual " +
150-
std::to_string(weights_ptr->get_ndim()) + "d");
151-
}
99+
check_num_dims(weights_ptr, 1, names);
152100

153101
auto sample_size = sample.get_size();
154102
auto weights_size = weights_ptr->get_size();
155103
if (sample.get_size() != weights_ptr->get_size()) {
156-
throw py::value_error(
157-
get_name(&sample) + " size (" + std::to_string(sample_size) +
158-
") and " + get_name(weights_ptr) + " size (" +
159-
std::to_string(weights_size) + ")" + " must match");
104+
throw py::value_error(name_of(&sample, names) + " size (" +
105+
std::to_string(sample_size) + ") and " +
106+
name_of(weights_ptr, names) + " size (" +
107+
std::to_string(weights_size) + ")" +
108+
" must match");
160109
}
161110
}
162111

163-
if (sample.get_ndim() > 2) {
164-
throw py::value_error(
165-
get_name(&sample) +
166-
" parameter must have no more than 2 dimensions. Actual " +
167-
std::to_string(sample.get_ndim()) + "d");
168-
}
112+
check_max_dims(&sample, 2, names);
169113

170114
if (sample.get_ndim() == 1) {
171-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
172-
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
173-
get_name(bins_ptr) + " is " +
174-
std::to_string(bins_ptr->get_ndim()) + "d");
175-
}
115+
check_num_dims(bins_ptr, 1, names);
176116
}
177117
else if (sample.get_ndim() == 2) {
178118
auto sample_count = sample.get_shape(0);
179119
auto expected_dims = sample.get_shape(1);
180120

181121
if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
182-
throw py::value_error(get_name(&sample) + " parameter has shape {" +
183-
std::to_string(sample_count) + "x" +
184-
std::to_string(expected_dims) + "}" +
185-
", so " + get_name(bins_ptr) +
186-
" parameter expected to be " +
187-
std::to_string(expected_dims) +
188-
"d. "
189-
"Actual " +
190-
std::to_string(bins->get_ndim()) + "d");
122+
throw py::value_error(
123+
name_of(&sample, names) + " parameter has shape {" +
124+
std::to_string(sample_count) + "x" +
125+
std::to_string(expected_dims) + "}" + ", so " +
126+
name_of(bins_ptr, names) + " parameter expected to be " +
127+
std::to_string(expected_dims) +
128+
"d. "
129+
"Actual " +
130+
std::to_string(bins->get_ndim()) + "d");
191131
}
192132
}
193133

@@ -199,17 +139,17 @@ void validate(const usm_ndarray &sample,
199139

200140
if (histogram.get_size() != expected_hist_size) {
201141
throw py::value_error(
202-
get_name(&histogram) + " and " + get_name(bins_ptr) +
203-
" shape mismatch. " + get_name(&histogram) +
204-
" expected to have size = " +
142+
name_of(&histogram, names) + " and " +
143+
name_of(bins_ptr, names) + " shape mismatch. " +
144+
name_of(&histogram, names) + " expected to have size = " +
205145
std::to_string(expected_hist_size) + ". Actual " +
206146
std::to_string(histogram.get_size()));
207147
}
208148
}
209149

210150
int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
211151
if (histogram.get_size() > max_hist_size) {
212-
throw py::value_error(get_name(&histogram) +
152+
throw py::value_error(name_of(&histogram, names) +
213153
" parameter size expected to be less than " +
214154
std::to_string(max_hist_size) + ". Actual " +
215155
std::to_string(histogram.get_size()));
@@ -225,7 +165,7 @@ void validate(const usm_ndarray &sample,
225165
if (!_64bit_atomics) {
226166
auto device_name = device.get_info<sycl::info::device::name>();
227167
throw py::value_error(
228-
get_name(&histogram) +
168+
name_of(&histogram, names) +
229169
" parameter has 64-bit type, but 64-bit atomics " +
230170
" are not supported for " + device_name);
231171
}

0 commit comments

Comments
 (0)