Skip to content

Commit 6117821

Browse files
ENH: update dpnp.fft; MKL integration (#820)
* ENH: update dpnp.fft; MKL integration (comlex to complex, 64/128)
1 parent d797599 commit 6117821

File tree

6 files changed

+160
-32
lines changed

6 files changed

+160
-32
lines changed

dpnp/backend/include/dpnp_iface_fft.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
* @param[in] axis Axis ID to compute by.
7171
* @param[in] input_boundarie Limit number of elements for @ref axis.
7272
* @param[in] inverse Using inverse algorithm.
73+
* @param[in] norm Normalization mode. 0 - backward, 1 - forward.
7374
*/
7475
template <typename _DataType>
7576
INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,
@@ -79,6 +80,6 @@ INP_DLLEXPORT void dpnp_fft_fft_c(const void* array_in,
7980
size_t shape_size,
8081
long axis,
8182
long input_boundarie,
82-
size_t inverse);
83-
83+
size_t inverse,
84+
const size_t norm);
8485
#endif // BACKEND_IFACE_FFT_H

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 145 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333

3434
namespace mkl_dft = oneapi::mkl::dft;
3535

36+
typedef mkl_dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX> desc_dp_cmplx_t;
37+
typedef mkl_dft::descriptor<mkl_dft::precision::SINGLE, mkl_dft::domain::COMPLEX> desc_sp_cmplx_t;
38+
typedef mkl_dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::REAL> desc_dp_real_t;
39+
typedef mkl_dft::descriptor<mkl_dft::precision::SINGLE, mkl_dft::domain::REAL> desc_sp_real_t;
40+
3641
#ifdef _WIN32
3742
#ifndef M_PI // Windows compatibility
3843
#define M_PI 3.14159265358979323846
@@ -43,23 +48,24 @@ template <typename _KernelNameSpecialization1, typename _KernelNameSpecializatio
4348
class dpnp_fft_fft_c_kernel;
4449

4550
template <typename _DataType_input, typename _DataType_output>
46-
void dpnp_fft_fft_c(const void* array1_in,
47-
void* result1,
48-
const long* input_shape,
49-
const long* output_shape,
50-
size_t shape_size,
51-
long axis,
52-
long input_boundarie,
53-
size_t inverse)
51+
void dpnp_fft_fft_sycl_c(const void* array1_in,
52+
void* result1,
53+
const long* input_shape,
54+
const long* output_shape,
55+
size_t shape_size,
56+
const size_t result_size,
57+
const size_t input_size,
58+
long axis,
59+
long input_boundarie,
60+
size_t inverse)
5461
{
55-
const size_t input_size = std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<size_t>());
56-
const size_t result_size = std::accumulate(output_shape, output_shape + shape_size, 1, std::multiplies<size_t>());
5762
if (!(input_size && result_size && shape_size))
5863
{
5964
return;
6065
}
6166

6267
cl::sycl::event event;
68+
6369
const double kernel_pi = inverse ? -M_PI : M_PI;
6470

6571
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, input_size);
@@ -148,21 +154,139 @@ void dpnp_fft_fft_c(const void* array1_in,
148154
};
149155

150156
event = DPNP_QUEUE.submit(kernel_func);
157+
event.wait();
158+
159+
dpnp_memory_free_c(input_shape_offsets);
160+
dpnp_memory_free_c(output_shape_offsets);
161+
dpnp_memory_free_c(axis_iterator);
162+
163+
return;
164+
}
151165

152-
#if 0 // keep this code
153-
oneapi::mkl::dft::descriptor<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX> desc(result_size);
154-
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, static_cast<double>(result_size));
155-
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); // enum value from math library C interface
166+
template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
167+
void dpnp_fft_fft_mathlib_compute_c(const void* array1_in,
168+
void* result1,
169+
const size_t shape_size,
170+
const size_t result_size,
171+
_Descriptor_type& desc,
172+
const size_t norm)
173+
{
174+
cl::sycl::event event;
175+
176+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, result_size);
177+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, result_size);
178+
_DataType_input* array_1 = input1_ptr.get_ptr();
179+
_DataType_output* result = result_ptr.get_ptr();
180+
181+
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, (1.0 / result_size));
182+
// enum value from math library C interface
183+
// instead of mkl_dft::config_value::NOT_INPLACE
184+
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
156185
desc.commit(DPNP_QUEUE);
157186

158187
event = mkl_dft::compute_forward(desc, array_1, result);
159-
#endif
160188

161189
event.wait();
162190

163-
dpnp_memory_free_c(input_shape_offsets);
164-
dpnp_memory_free_c(output_shape_offsets);
165-
dpnp_memory_free_c(axis_iterator);
191+
return;
192+
}
193+
194+
// norm: backward - 0, forward is 1
195+
template <typename _DataType_input, typename _DataType_output>
196+
void dpnp_fft_fft_mathlib_c(const void* array1_in,
197+
void* result1,
198+
const long* input_shape,
199+
const size_t shape_size,
200+
const size_t result_size,
201+
const size_t norm)
202+
{
203+
if (!shape_size || !result_size || !array1_in || !result1)
204+
{
205+
return;
206+
}
207+
std::vector<std::int64_t> dimensions(input_shape, input_shape + shape_size);
208+
209+
if constexpr (std::is_same<_DataType_input, std::complex<double>>::value &&
210+
std::is_same<_DataType_output, std::complex<double>>::value)
211+
{
212+
if (shape_size == 1)
213+
{
214+
desc_dp_cmplx_t desc(result_size);
215+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
216+
array1_in, result1, shape_size, result_size, desc, norm);
217+
}
218+
else
219+
{
220+
desc_dp_cmplx_t desc(dimensions);
221+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
222+
array1_in, result1, shape_size, result_size, desc, norm);
223+
}
224+
}
225+
else if (std::is_same<_DataType_input, std::complex<float>>::value &&
226+
std::is_same<_DataType_output, std::complex<float>>::value)
227+
{
228+
if (shape_size == 1)
229+
{
230+
desc_sp_cmplx_t desc(result_size);
231+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
232+
array1_in, result1, shape_size, result_size, desc, norm);
233+
}
234+
else
235+
{
236+
desc_sp_cmplx_t desc(dimensions);
237+
dpnp_fft_fft_mathlib_compute_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
238+
array1_in, result1, shape_size, result_size, desc, norm);
239+
}
240+
}
241+
return;
242+
}
243+
244+
template <typename _DataType_input, typename _DataType_output>
245+
void dpnp_fft_fft_c(const void* array1_in,
246+
void* result1,
247+
const long* input_shape,
248+
const long* output_shape,
249+
size_t shape_size,
250+
long axis,
251+
long input_boundarie,
252+
size_t inverse,
253+
const size_t norm)
254+
{
255+
if (!shape_size)
256+
{
257+
return;
258+
}
259+
260+
const size_t result_size = std::accumulate(output_shape, output_shape + shape_size, 1, std::multiplies<size_t>());
261+
const size_t input_size = std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<size_t>());
262+
263+
if (!input_size || !result_size || !array1_in || !result1)
264+
{
265+
return;
266+
}
267+
268+
if (((std::is_same<_DataType_input, std::complex<double>>::value &&
269+
std::is_same<_DataType_output, std::complex<double>>::value) ||
270+
(std::is_same<_DataType_input, std::complex<float>>::value &&
271+
std::is_same<_DataType_output, std::complex<float>>::value)) &&
272+
(shape_size <= 3))
273+
{
274+
dpnp_fft_fft_mathlib_c<_DataType_input, _DataType_output>(
275+
array1_in, result1, input_shape, shape_size, result_size, norm);
276+
}
277+
else
278+
{
279+
dpnp_fft_fft_sycl_c<_DataType_input, _DataType_output>(array1_in,
280+
result1,
281+
input_shape,
282+
output_shape,
283+
shape_size,
284+
result_size,
285+
input_size,
286+
axis,
287+
input_boundarie,
288+
inverse);
289+
}
166290

167291
return;
168292
}
@@ -173,12 +297,12 @@ void func_map_init_fft_func(func_map_t& fmap)
173297
(void*)dpnp_fft_fft_c<int, std::complex<double>>};
174298
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_LNG][eft_LNG] = {eft_C128,
175299
(void*)dpnp_fft_fft_c<long, std::complex<double>>};
176-
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_FLT][eft_FLT] = {eft_C128,
177-
(void*)dpnp_fft_fft_c<float, std::complex<double>>};
300+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_FLT][eft_FLT] = {eft_C64,
301+
(void*)dpnp_fft_fft_c<float, std::complex<float>>};
178302
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_DBL][eft_DBL] = {eft_C128,
179303
(void*)dpnp_fft_fft_c<double, std::complex<double>>};
180304
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_C64][eft_C64] = {
181-
eft_C128, (void*)dpnp_fft_fft_c<std::complex<float>, std::complex<double>>};
305+
eft_C64, (void*)dpnp_fft_fft_c<std::complex<float>, std::complex<float>>};
182306
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_C128][eft_C128] = {
183307
eft_C128, (void*)dpnp_fft_fft_c<std::complex<double>, std::complex<double>>};
184308
return;

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ __all__ = [
4141
"dpnp_fft"
4242
]
4343

44-
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , long * , long * , size_t, long, long, size_t)
44+
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , long * , long * , size_t, long, long, size_t, size_t)
4545

4646

4747
cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
4848
size_t input_boundarie,
4949
size_t output_boundarie,
5050
long axis,
51-
size_t inverse):
51+
size_t inverse,
52+
size_t norm):
5253

5354
cdef shape_type_c input_shape = input.shape
5455
cdef shape_type_c output_shape = input_shape
@@ -68,6 +69,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
6869
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
6970
# call FPTR function
7071
func(input.get_data(), result.get_data(), input_shape.data(),
71-
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse)
72+
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse, norm)
7273

7374
return result

dpnp/fft/dpnp_iface_fft.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,19 @@ def fft(x1, n=None, axis=-1, norm=None):
7676
Limitations
7777
-----------
7878
Parameter ``norm`` is unsupported.
79-
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64`` and
80-
``dpnp.complex128`` datatypes only.
79+
Parameter ``x1`` supports ``dpnp.int32``, ``dpnp.int64``, ``dpnp.float32``, ``dpnp.float64``,
80+
``dpnp.complex64`` and ``dpnp.complex128`` datatypes only.
8181
8282
For full documentation refer to :obj:`numpy.fft.fft`.
8383
8484
"""
8585

8686
x1_desc = dpnp.get_dpnp_descriptor(x1)
8787
if x1_desc:
88+
# if norm is None or norm is 'backward':
89+
# norm_val = 0
90+
# else:
91+
# norm_val = 1
8892
if axis is None:
8993
axis_param = -1 # the most right dimension (default value)
9094
else:
@@ -104,7 +108,7 @@ def fft(x1, n=None, axis=-1, norm=None):
104108
else:
105109
output_boundarie = input_boundarie
106110

107-
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False).get_pyobj()
111+
return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, False, 0).get_pyobj()
108112

109113
return call_origin(numpy.fft.fft, x1, n, axis, norm)
110114

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,6 @@ tests/test_dparray.py::test_flatten[[]-int32]
399399
tests/test_dparray.py::test_flatten[[]-bool]
400400
tests/test_dparray.py::test_flatten[[]-bool_]
401401
tests/test_dparray.py::test_flatten[[]-complex]
402-
tests/test_fft.py::test_fft[complex128]
403-
tests/test_fft.py::test_fft[complex64]
404402
tests/test_fft.py::test_fft[float32]
405403
tests/test_fft.py::test_fft[float64]
406404
tests/test_fft.py::test_fft[int32]

tests/test_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_fft(type):
1515
dpnp_data = dpnp.array(data)
1616

1717
np_res = numpy.fft.fft(data)
18-
dpnp_res = dpnp.fft.fft(dpnp_data)
18+
dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data))
1919

2020
numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7)
2121
assert dpnp_res.dtype == np_res.dtype

0 commit comments

Comments
 (0)