Skip to content

Commit d1c88a6

Browse files
vtavanandgrigorian
authored andcommitted
impl_ele_funcs_floor_ceil_trunc
1 parent 7bfc5a0 commit d1c88a6

File tree

8 files changed

+1030
-15
lines changed

8 files changed

+1030
-15
lines changed

dpctl/tensor/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
ceil,
9798
conj,
9899
cos,
99100
divide,
100101
equal,
101102
exp,
102103
expm1,
104+
floor,
103105
floor_divide,
104106
greater,
105107
greater_equal,
@@ -122,6 +124,7 @@
122124
sin,
123125
sqrt,
124126
subtract,
127+
trunc,
125128
)
126129
from ._reduction import sum
127130

@@ -202,10 +205,15 @@
202205
"inf",
203206
"abs",
204207
"add",
208+
"ceil",
205209
"conj",
206210
"cos",
211+
"divide",
212+
"equal",
207213
"exp",
208214
"expm1",
215+
"floor",
216+
"floor_divide",
209217
"greater",
210218
"greater_equal",
211219
"imag",
@@ -220,15 +228,13 @@
220228
"logical_or",
221229
"logical_xor",
222230
"log1p",
231+
"multiply",
232+
"not_equal",
223233
"proj",
224234
"real",
225235
"sin",
226236
"sqrt",
227-
"divide",
228-
"multiply",
229237
"subtract",
230-
"equal",
231-
"not_equal",
232238
"sum",
233-
"floor_divide",
239+
"trunc",
234240
]

dpctl/tensor/_elementwise_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ def __call__(self, x, out=None, order="K"):
8181
x.dtype, self.result_type_resolver_fn_, x.sycl_device
8282
)
8383
if res_dt is None:
84-
raise RuntimeError
84+
raise TypeError(
85+
f"function '{self.name_}' does not support input type "
86+
f"({x.dtype}), "
87+
"and the input could not be safely coerced to any "
88+
"supported types according to the casting rule ''safe''."
89+
)
8590
exec_q = x.sycl_queue
8691
if buf_dt is None:
8792
if out is None:

dpctl/tensor/_elementwise_funcs.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,31 @@
114114
# FIXME: implement B07
115115

116116
# U09: ==== CEIL (x)
117-
# FIXME: implement U09
117+
_ceil_docstring = """
118+
ceil(x, out=None, order='K')
119+
120+
Returns the ceiling for each element `x_i` for input array `x`.
121+
The ceil of the scalar `x` is the smallest integer `i`, such that `i >= x`.
122+
123+
Args:
124+
x (usm_ndarray):
125+
Input array, expected to have numeric data type.
126+
out ({None, usm_ndarray}, optional):
127+
Output array to populate.
128+
Array must have the correct shape and the expected data type.
129+
order ("C","F","A","K", optional):
130+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
131+
Default: "K".
132+
Returns:
133+
usm_ndarray:
134+
An array containing the element-wise ceiling of the input.
135+
The data type
136+
of the returned array is determined by the Type Promotion Rules.
137+
"""
138+
139+
ceil = UnaryElementwiseFunc(
140+
"ceil", ti._ceil_result_type, ti._ceil, _ceil_docstring
141+
)
118142

119143
# U10: ==== CONJ (x)
120144
_conj_docstring = """
@@ -271,7 +295,31 @@
271295
)
272296

273297
# U15: ==== FLOOR (x)
274-
# FIXME: implement U15
298+
_floor_docstring = """
299+
floor(x, out=None, order='K')
300+
301+
Returns the floor for each element `x_i` for input array `x`.
302+
The floor of the scalar `x` is the largest integer `i`, such that `i <= x`.
303+
304+
Args:
305+
x (usm_ndarray):
306+
Input array, expected to have numeric data type.
307+
out ({None, usm_ndarray}, optional):
308+
Output array to populate.
309+
Array must have the correct shape and the expected data type.
310+
order ("C","F","A","K", optional):
311+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
312+
Default: "K".
313+
Returns:
314+
usm_ndarray:
315+
An array containing the element-wise floor of the input.
316+
The data type
317+
of the returned array is determined by the Type Promotion Rules.
318+
"""
319+
320+
floor = UnaryElementwiseFunc(
321+
"floor", ti._floor_result_type, ti._floor, _floor_docstring
322+
)
275323

276324
# B10: ==== FLOOR_DIVIDE (x1, x2)
277325
_floor_divide_docstring_ = """
@@ -905,4 +953,30 @@
905953
# FIXME: implement U35
906954

907955
# U36: ==== TRUNC (x)
908-
# FIXME: implement U36
956+
_trunc_docstring = """
957+
trunc(x, out=None, order='K')
958+
959+
Returns the truncated value for each element `x_i` for input array `x`.
960+
The truncated value of the scalar `x` is the nearest integer i which is
961+
closer to zero than `x` is. In short, the fractional part of the
962+
signed number `x` is discarded.
963+
964+
Args:
965+
x (usm_ndarray):
966+
Input array, expected to have numeric data type.
967+
out ({None, usm_ndarray}, optional):
968+
Output array to populate.
969+
Array must have the correct shape and the expected data type.
970+
order ("C","F","A","K", optional):
971+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
972+
Default: "K".
973+
Returns:
974+
usm_ndarray:
975+
An array containing the element-wise truncated values of the input.
976+
The data type
977+
of the returned array is determined by the Type Promotion Rules.
978+
"""
979+
980+
trunc = UnaryElementwiseFunc(
981+
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
982+
)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//=== ceil.hpp - Unary function CEIL ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of CEIL(x) function.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cmath>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "kernels/elementwise_functions/common.hpp"
33+
34+
#include "utils/offset_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace ceil
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
using dpctl::tensor::type_utils::is_complex;
52+
53+
template <typename argT, typename resT> struct CeilFunctor
54+
{
55+
56+
// is function constant for given argT
57+
using is_constant = typename std::false_type;
58+
// constant value, if constant
59+
// constexpr resT constant_value = resT{};
60+
// is function defined for sycl::vec
61+
using supports_vec = typename std::false_type;
62+
// do both argT and resT support subgroup store/load operation
63+
using supports_sg_loadstore = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
65+
66+
resT operator()(const argT &in)
67+
{
68+
if constexpr (std::is_integral_v<argT>) {
69+
return in;
70+
}
71+
else {
72+
return sycl::ceil(in);
73+
}
74+
// return sycl::ceil(in);
75+
}
76+
};
77+
78+
template <typename argTy,
79+
typename resTy = argTy,
80+
unsigned int vec_sz = 4,
81+
unsigned int n_vecs = 2>
82+
using CeilContigFunctor = elementwise_common::
83+
UnaryContigFunctor<argTy, resTy, CeilFunctor<argTy, resTy>, vec_sz, n_vecs>;
84+
85+
template <typename argTy, typename resTy, typename IndexerT>
86+
using CeilStridedFunctor = elementwise_common::
87+
UnaryStridedFunctor<argTy, resTy, IndexerT, CeilFunctor<argTy, resTy>>;
88+
89+
template <typename T> struct CeilOutputType
90+
{
91+
using value_type = typename std::disjunction< // disjunction is C++17
92+
// feature, supported by DPC++
93+
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
94+
td_ns::TypeMapResultEntry<T, sycl::half>,
95+
td_ns::TypeMapResultEntry<T, float>,
96+
td_ns::TypeMapResultEntry<T, double>,
97+
td_ns::DefaultResultEntry<void>>::result_type;
98+
};
99+
100+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
101+
class ceil_contig_kernel;
102+
103+
template <typename argTy>
104+
sycl::event ceil_contig_impl(sycl::queue exec_q,
105+
size_t nelems,
106+
const char *arg_p,
107+
char *res_p,
108+
const std::vector<sycl::event> &depends = {})
109+
{
110+
return elementwise_common::unary_contig_impl<
111+
argTy, CeilOutputType, CeilContigFunctor, ceil_contig_kernel>(
112+
exec_q, nelems, arg_p, res_p, depends);
113+
}
114+
115+
template <typename fnT, typename T> struct CeilContigFactory
116+
{
117+
fnT get()
118+
{
119+
if constexpr (std::is_same_v<typename CeilOutputType<T>::value_type,
120+
void>) {
121+
fnT fn = nullptr;
122+
return fn;
123+
}
124+
else {
125+
fnT fn = ceil_contig_impl<T>;
126+
return fn;
127+
}
128+
}
129+
};
130+
131+
template <typename fnT, typename T> struct CeilTypeMapFactory
132+
{
133+
/*! @brief get typeid for output type of sycl::ceil(T x) */
134+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
135+
{
136+
using rT = typename CeilOutputType<T>::value_type;
137+
return td_ns::GetTypeid<rT>{}.get();
138+
}
139+
};
140+
141+
template <typename T1, typename T2, typename T3> class ceil_strided_kernel;
142+
143+
template <typename argTy>
144+
sycl::event
145+
ceil_strided_impl(sycl::queue exec_q,
146+
size_t nelems,
147+
int nd,
148+
const py::ssize_t *shape_and_strides,
149+
const char *arg_p,
150+
py::ssize_t arg_offset,
151+
char *res_p,
152+
py::ssize_t res_offset,
153+
const std::vector<sycl::event> &depends,
154+
const std::vector<sycl::event> &additional_depends)
155+
{
156+
return elementwise_common::unary_strided_impl<
157+
argTy, CeilOutputType, CeilStridedFunctor, ceil_strided_kernel>(
158+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
159+
res_offset, depends, additional_depends);
160+
}
161+
162+
template <typename fnT, typename T> struct CeilStridedFactory
163+
{
164+
fnT get()
165+
{
166+
if constexpr (std::is_same_v<typename CeilOutputType<T>::value_type,
167+
void>) {
168+
fnT fn = nullptr;
169+
return fn;
170+
}
171+
else {
172+
fnT fn = ceil_strided_impl<T>;
173+
return fn;
174+
}
175+
}
176+
};
177+
178+
} // namespace ceil
179+
} // namespace kernels
180+
} // namespace tensor
181+
} // namespace dpctl

0 commit comments

Comments
 (0)