Skip to content

Commit 8a019d1

Browse files
Merge pull request #1243 from IntelPython/elementwise-floor-ceil-trunc
impl_ele_funcs_floor_ceil_trunc
2 parents 0eb2fdf + a2c0aee commit 8a019d1

File tree

9 files changed

+1051
-17
lines changed

9 files changed

+1051
-17
lines changed

Diff for: dpctl/tensor/__init__.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
ceil,
9798
conj,
9899
cos,
99100
divide,
100101
equal,
101102
exp,
102103
expm1,
104+
floor,
103105
floor_divide,
104106
greater,
105107
greater_equal,
@@ -128,6 +130,7 @@
128130
sqrt,
129131
square,
130132
subtract,
133+
trunc,
131134
)
132135
from ._reduction import sum
133136

@@ -208,16 +211,21 @@
208211
"inf",
209212
"abs",
210213
"add",
214+
"ceil",
211215
"conj",
212216
"cos",
217+
"divide",
218+
"equal",
213219
"exp",
214220
"expm1",
221+
"floor",
222+
"floor_divide",
215223
"greater",
216224
"greater_equal",
217225
"imag",
226+
"isfinite",
218227
"isinf",
219228
"isnan",
220-
"isfinite",
221229
"less",
222230
"less_equal",
223231
"log",
@@ -228,19 +236,17 @@
228236
"log1p",
229237
"log2",
230238
"log10",
239+
"multiply",
231240
"negative",
241+
"not_equal",
232242
"positive",
243+
"pow",
233244
"proj",
234245
"real",
235246
"sin",
236247
"sqrt",
237248
"square",
238-
"divide",
239-
"multiply",
240-
"pow",
241249
"subtract",
242-
"equal",
243-
"not_equal",
244250
"sum",
245-
"floor_divide",
251+
"trunc",
246252
]

Diff for: dpctl/tensor/_elementwise_common.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def __call__(self, x, out=None, order="K"):
5858
x.dtype, self.result_type_resolver_fn_, x.sycl_device
5959
)
6060
if res_dt is None:
61-
raise RuntimeError
61+
raise TypeError(
62+
f"function '{self.name_}' does not support input type "
63+
f"({x.dtype}), "
64+
"and the input could not be safely coerced to any "
65+
"supported types according to the casting rule ''safe''."
66+
)
6267

6368
orig_out = out
6469
if out is not None:

Diff for: dpctl/tensor/_elementwise_funcs.py

+74-3
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,30 @@
114114
# FIXME: implement B07
115115

116116
# U09: ==== CEIL (x)
117-
# FIXME: implement U09
117+
_ceil_docstring = """
118+
ceil(x, out=None, order='K')
119+
120+
Returns the ceiling for each element `x_i` for input array `x`.
121+
The ceil of the scalar `x` is the smallest integer `i`, such that `i >= x`.
122+
123+
Args:
124+
x (usm_ndarray):
125+
Input array, expected to have numeric data type.
126+
out ({None, usm_ndarray}, optional):
127+
Output array to populate.
128+
Array have the correct shape and the expected data type.
129+
order ("C","F","A","K", optional):
130+
Memory layout of the newly output array, if parameter `out` is `None`.
131+
Default: "K".
132+
Returns:
133+
usm_narray:
134+
An array containing the element-wise ceiling of input array.
135+
The returned array has the same data type as `x`.
136+
"""
137+
138+
ceil = UnaryElementwiseFunc(
139+
"ceil", ti._ceil_result_type, ti._ceil, _ceil_docstring
140+
)
118141

119142
# U10: ==== CONJ (x)
120143
_conj_docstring = """
@@ -271,7 +294,30 @@
271294
)
272295

273296
# U15: ==== FLOOR (x)
274-
# FIXME: implement U15
297+
_floor_docstring = """
298+
floor(x, out=None, order='K')
299+
300+
Returns the floor for each element `x_i` for input array `x`.
301+
The floor of the scalar `x` is the largest integer `i`, such that `i <= x`.
302+
303+
Args:
304+
x (usm_ndarray):
305+
Input array, expected to have numeric data type.
306+
out ({None, usm_ndarray}, optional):
307+
Output array to populate.
308+
Array have the correct shape and the expected data type.
309+
order ("C","F","A","K", optional):
310+
Memory layout of the newly output array, if parameter `out` is `None`.
311+
Default: "K".
312+
Returns:
313+
usm_narray:
314+
An array containing the element-wise floor of input array.
315+
The returned array has the same data type as `x`.
316+
"""
317+
318+
floor = UnaryElementwiseFunc(
319+
"floor", ti._floor_result_type, ti._floor, _floor_docstring
320+
)
275321

276322
# B10: ==== FLOOR_DIVIDE (x1, x2)
277323
_floor_divide_docstring_ = """
@@ -1031,4 +1077,29 @@
10311077
# FIXME: implement U35
10321078

10331079
# U36: ==== TRUNC (x)
1034-
# FIXME: implement U36
1080+
_trunc_docstring = """
1081+
trunc(x, out=None, order='K')
1082+
1083+
Returns the truncated value for each element `x_i` for input array `x`.
1084+
The truncated value of the scalar `x` is the nearest integer i which is
1085+
closer to zero than `x` is. In short, the fractional part of the
1086+
signed number `x` is discarded.
1087+
1088+
Args:
1089+
x (usm_ndarray):
1090+
Input array, expected to have numeric data type.
1091+
out ({None, usm_ndarray}, optional):
1092+
Output array to populate.
1093+
Array have the correct shape and the expected data type.
1094+
order ("C","F","A","K", optional):
1095+
Memory layout of the newly output array, if parameter `out` is `None`.
1096+
Default: "K".
1097+
Returns:
1098+
usm_narray:
1099+
An array containing the element-wise truncated value of input array.
1100+
The returned array has the same data type as `x`.
1101+
"""
1102+
1103+
trunc = UnaryElementwiseFunc(
1104+
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1105+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
//=== ceil.hpp - Unary function CEIL ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of CEIL(x) function.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cmath>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "kernels/elementwise_functions/common.hpp"
33+
34+
#include "utils/offset_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace ceil
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
using dpctl::tensor::type_utils::is_complex;
52+
53+
template <typename argT, typename resT> struct CeilFunctor
54+
{
55+
56+
// is function constant for given argT
57+
using is_constant = typename std::false_type;
58+
// constant value, if constant
59+
// constexpr resT constant_value = resT{};
60+
// is function defined for sycl::vec
61+
using supports_vec = typename std::false_type;
62+
// do both argTy and resTy support sugroup store/load operation
63+
using supports_sg_loadstore = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
65+
66+
resT operator()(const argT &in)
67+
{
68+
if constexpr (std::is_integral_v<argT>) {
69+
return in;
70+
}
71+
else {
72+
if (in == 0) {
73+
return in;
74+
}
75+
return std::ceil(in);
76+
}
77+
}
78+
};
79+
80+
template <typename argTy,
81+
typename resTy = argTy,
82+
unsigned int vec_sz = 4,
83+
unsigned int n_vecs = 2>
84+
using CeilContigFunctor = elementwise_common::
85+
UnaryContigFunctor<argTy, resTy, CeilFunctor<argTy, resTy>, vec_sz, n_vecs>;
86+
87+
template <typename argTy, typename resTy, typename IndexerT>
88+
using CeilStridedFunctor = elementwise_common::
89+
UnaryStridedFunctor<argTy, resTy, IndexerT, CeilFunctor<argTy, resTy>>;
90+
91+
template <typename T> struct CeilOutputType
92+
{
93+
using value_type = typename std::disjunction< // disjunction is C++17
94+
// feature, supported by DPC++
95+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
96+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
97+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
98+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
99+
td_ns::TypeMapResultEntry<T, std::int8_t>,
100+
td_ns::TypeMapResultEntry<T, std::int16_t>,
101+
td_ns::TypeMapResultEntry<T, std::int32_t>,
102+
td_ns::TypeMapResultEntry<T, std::int64_t>,
103+
td_ns::TypeMapResultEntry<T, sycl::half>,
104+
td_ns::TypeMapResultEntry<T, float>,
105+
td_ns::TypeMapResultEntry<T, double>,
106+
td_ns::DefaultResultEntry<void>>::result_type;
107+
};
108+
109+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
110+
class ceil_contig_kernel;
111+
112+
template <typename argTy>
113+
sycl::event ceil_contig_impl(sycl::queue exec_q,
114+
size_t nelems,
115+
const char *arg_p,
116+
char *res_p,
117+
const std::vector<sycl::event> &depends = {})
118+
{
119+
return elementwise_common::unary_contig_impl<
120+
argTy, CeilOutputType, CeilContigFunctor, ceil_contig_kernel>(
121+
exec_q, nelems, arg_p, res_p, depends);
122+
}
123+
124+
template <typename fnT, typename T> struct CeilContigFactory
125+
{
126+
fnT get()
127+
{
128+
if constexpr (std::is_same_v<typename CeilOutputType<T>::value_type,
129+
void>) {
130+
fnT fn = nullptr;
131+
return fn;
132+
}
133+
else {
134+
fnT fn = ceil_contig_impl<T>;
135+
return fn;
136+
}
137+
}
138+
};
139+
140+
template <typename fnT, typename T> struct CeilTypeMapFactory
141+
{
142+
/*! @brief get typeid for output type of sycl::ceil(T x) */
143+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
144+
{
145+
using rT = typename CeilOutputType<T>::value_type;
146+
return td_ns::GetTypeid<rT>{}.get();
147+
}
148+
};
149+
150+
template <typename T1, typename T2, typename T3> class ceil_strided_kernel;
151+
152+
template <typename argTy>
153+
sycl::event
154+
ceil_strided_impl(sycl::queue exec_q,
155+
size_t nelems,
156+
int nd,
157+
const py::ssize_t *shape_and_strides,
158+
const char *arg_p,
159+
py::ssize_t arg_offset,
160+
char *res_p,
161+
py::ssize_t res_offset,
162+
const std::vector<sycl::event> &depends,
163+
const std::vector<sycl::event> &additional_depends)
164+
{
165+
return elementwise_common::unary_strided_impl<
166+
argTy, CeilOutputType, CeilStridedFunctor, ceil_strided_kernel>(
167+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
168+
res_offset, depends, additional_depends);
169+
}
170+
171+
template <typename fnT, typename T> struct CeilStridedFactory
172+
{
173+
fnT get()
174+
{
175+
if constexpr (std::is_same_v<typename CeilOutputType<T>::value_type,
176+
void>) {
177+
fnT fn = nullptr;
178+
return fn;
179+
}
180+
else {
181+
fnT fn = ceil_strided_impl<T>;
182+
return fn;
183+
}
184+
}
185+
};
186+
187+
} // namespace ceil
188+
} // namespace kernels
189+
} // namespace tensor
190+
} // namespace dpctl

0 commit comments

Comments
 (0)