Skip to content

Commit 91fd41f

Browse files
Apply review comments
1 parent 9130a53 commit 91fd41f

15 files changed

+40
-98
lines changed

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@
3232

3333
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3434

35-
namespace statistics
36-
{
37-
namespace histogram
35+
namespace statistics::histogram
3836
{
3937
struct Bincount
4038
{
@@ -62,5 +60,4 @@ struct Bincount
6260
};
6361

6462
void populate_bincount(py::module_ m);
65-
} // namespace histogram
66-
} // namespace statistics
63+
} // namespace statistics::histogram

dpnp/backend/extensions/statistics/common.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@
2929

3030
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3131

32-
namespace statistics
32+
namespace statistics::common
3333
{
34-
namespace common
35-
{
36-
3734
size_t get_max_local_size(const sycl::device &device)
3835
{
3936
constexpr const int default_max_cpu_local_size = 256;
@@ -120,5 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
120117
}
121118
}
122119

123-
} // namespace common
124-
} // namespace statistics
120+
} // namespace statistics::common

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@
3636
#include "utils/math_utils.hpp"
3737
// clang-format on
3838

39-
namespace statistics
40-
{
41-
namespace common
39+
namespace statistics::common
4240
{
4341

4442
template <typename N, typename D>
@@ -200,5 +198,4 @@ sycl::nd_range<1>
200198
// headers of dpctl.
201199
pybind11::dtype dtype_from_typenum(int dst_typenum);
202200

203-
} // namespace common
204-
} // namespace statistics
201+
} // namespace statistics::common

dpnp/backend/extensions/statistics/dispatch_table.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,8 @@
3939
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4040
namespace py = pybind11;
4141

42-
namespace statistics
42+
namespace statistics::common
4343
{
44-
namespace common
45-
{
46-
4744
template <typename T, typename Rest>
4845
struct one_of
4946
{
@@ -386,5 +383,4 @@ class DispatchTable2
386383
Table2<FnT> table;
387384
};
388385

389-
} // namespace common
390-
} // namespace statistics
386+
} // namespace statistics::common

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
#include <algorithm>
2727
#include <complex>
2828
#include <memory>
29-
#include <string>
30-
#include <type_traits>
31-
#include <unordered_map>
29+
#include <tuple>
3230
#include <vector>
3331

3432
#include <pybind11/pybind11.h>

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@
2929

3030
#include "dispatch_table.hpp"
3131

32-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
32+
// namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3333

34-
namespace statistics
35-
{
36-
namespace histogram
34+
namespace statistics::histogram
3735
{
3836
struct Histogram
3937
{
@@ -59,5 +57,4 @@ struct Histogram
5957
};
6058

6159
void populate_histogram(py::module_ m);
62-
} // namespace histogram
63-
} // namespace statistics
60+
} // namespace statistics::histogram

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,9 @@
2626
#include <algorithm>
2727
#include <limits>
2828
#include <string>
29-
#include <unordered_map>
3029
#include <vector>
3130

3231
#include "dpctl4pybind11.hpp"
33-
#include "utils/memory_overlap.hpp"
34-
#include "utils/output_validation.hpp"
3532
#include "utils/type_dispatch.hpp"
3633

3734
#include <pybind11/pybind11.h>

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@
2929

3030
#include "common.hpp"
3131

32-
namespace dpctl
33-
{
34-
namespace tensor
35-
{
36-
class usm_ndarray;
37-
}
38-
} // namespace dpctl
39-
4032
using dpctl::tensor::usm_ndarray;
4133

4234
namespace statistics

dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,8 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include <algorithm>
2726
#include <complex>
2827
#include <memory>
29-
#include <string>
30-
#include <type_traits>
31-
#include <unordered_map>
3228
#include <vector>
3329

3430
#include <pybind11/pybind11.h>
@@ -42,7 +38,7 @@
4238
#include "sliding_dot_product1d.hpp"
4339
#include "sliding_window1d.hpp"
4440

45-
#include <iostream>
41+
// #include <iostream>
4642

4743
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4844
using dpctl::tensor::usm_ndarray;
@@ -101,7 +97,9 @@ struct SlidingDotProductF
10197
}
10298
};
10399

104-
using SupportedTypes = std::tuple<uint64_t,
100+
using SupportedTypes = std::tuple<uint32_t,
101+
int32_t,
102+
uint64_t,
105103
int64_t,
106104
float,
107105
double,

dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,10 @@
2626
#pragma once
2727

2828
#include "dispatch_table.hpp"
29-
#include "utils/type_dispatch.hpp"
3029
#include <pybind11/pybind11.h>
3130
#include <sycl/sycl.hpp>
3231

33-
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
34-
35-
namespace statistics
36-
{
37-
namespace sliding_window1d
32+
namespace statistics::sliding_window1d
3833
{
3934
struct SlidingDotProduct1d
4035
{
@@ -62,5 +57,4 @@ struct SlidingDotProduct1d
6257
};
6358

6459
void populate_sliding_dot_product1d(py::module_ m);
65-
} // namespace sliding_window1d
66-
} // namespace statistics
60+
} // namespace statistics::sliding_window1d

dpnp/backend/extensions/statistics/sliding_window1d.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include <algorithm>
27-
#include <limits>
2826
#include <string>
29-
#include <unordered_map>
3027
#include <vector>
3128

3229
#include "dpctl4pybind11.hpp"

dpnp/backend/extensions/statistics/sliding_window1d.hpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,13 @@
2626
#pragma once
2727

2828
#include "utils/math_utils.hpp"
29-
#include <complex>
3029
#include <sycl/sycl.hpp>
31-
#include <tuple>
3230
#include <type_traits>
3331

3432
#include <stdio.h>
3533

3634
#include "common.hpp"
3735

38-
namespace dpctl
39-
{
40-
namespace tensor
41-
{
42-
class usm_ndarray;
43-
}
44-
} // namespace dpctl
45-
4636
using dpctl::tensor::usm_ndarray;
4737

4838
namespace statistics

dpnp/backend/extensions/statistics/validation_utils.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ sycl::queue get_queue(const std::vector<array_ptr> &inputs,
5353
}
5454
} // namespace
5555

56-
namespace statistics
57-
{
58-
namespace validation
56+
namespace statistics::validation
5957
{
6058
std::string name_of(const array_ptr &arr, const array_names &names)
6159
{
@@ -189,5 +187,4 @@ void common_checks(const std::vector<array_ptr> &inputs,
189187
check_no_overlap(inputs, outputs, names);
190188
}
191189

192-
} // namespace validation
193-
} // namespace statistics
190+
} // namespace statistics::validation

dpnp/backend/extensions/statistics/validation_utils.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131

3232
#include "dpctl4pybind11.hpp"
3333

34-
namespace statistics
35-
{
36-
namespace validation
34+
namespace statistics::validation
3735
{
3836
using array_ptr = const dpctl::tensor::usm_ndarray *;
3937
using array_names = std::unordered_map<array_ptr, std::string>;
@@ -69,5 +67,4 @@ void check_size_at_least(const array_ptr &arr,
6967
void common_checks(const std::vector<array_ptr> &inputs,
7068
const std::vector<array_ptr> &outputs,
7169
const array_names &names);
72-
} // namespace validation
73-
} // namespace statistics
70+
} // namespace statistics::validation

dpnp/dpnp_iface_statistics.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,7 @@ def corrcoef(x, y=None, rowvar=True, *, dtype=None):
440440

441441

442442
def _get_padding(a_size, v_size, mode):
443-
if v_size > a_size:
444-
a_size, v_size = v_size, a_size
443+
assert v_size > a_size
445444

446445
if mode == "valid":
447446
l_pad, r_pad = 0, 0
@@ -463,9 +462,8 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
463462

464463
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
465464
out_size = l_pad + r_pad + a.size - v.size + 1
466-
out = dpnp.empty(
467-
shape=out_size, sycl_queue=queue, dtype=a.dtype, usm_type=usm_type
468-
)
465+
# out type is the same as input type
466+
out = dpnp.empty_like(a, shape=out_size, usm_type=usm_type)
469467

470468
a_usm = dpnp.get_usm_ndarray(a)
471469
v_usm = dpnp.get_usm_ndarray(v)
@@ -491,11 +489,11 @@ def correlate(a, v, mode="valid"):
491489
Cross-correlation of two 1-dimensional sequences.
492490
493491
This function computes the correlation as generally defined in signal
494-
processing texts [1]:
492+
processing texts [1]_:
495493
496494
.. math:: c_k = \sum_n a_{n+k} \cdot \overline{v}_n
497495
498-
with a and v sequences being zero-padded where necessary and
496+
with `a` and `v` sequences being zero-padded where necessary and
499497
:math:`\overline v` denoting complex conjugation.
500498
501499
For full documentation refer to :obj:`numpy.correlate`.
@@ -506,16 +504,16 @@ def correlate(a, v, mode="valid"):
506504
First input array.
507505
v : {dpnp.ndarray, usm_ndarray}
508506
Second input array.
509-
mode : {'valid', 'same', 'full'}, optional
507+
mode : {"valid", "same", "full"}, optional
510508
Refer to the :obj:`dpnp.convolve` docstring. Note that the default
511-
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
509+
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
512510
513-
Default: ``'valid'``.
511+
Default: ``"valid"``.
514512
515513
Notes
516514
-----
517515
The definition of correlation above is not unique and sometimes
518-
correlation may be defined differently. Another common definition is [1]:
516+
correlation may be defined differently. Another common definition is [1]_:
519517
520518
.. math:: c'_k = \sum_n a_{n} \cdot \overline{v_{n+k}}
521519
@@ -533,8 +531,8 @@ def correlate(a, v, mode="valid"):
533531
534532
See Also
535533
--------
536-
:obj:`dpnp.convolve` : Discrete, linear convolution of two
537-
one-dimensional sequences.
534+
:obj:`dpnp.convolve` : Discrete, linear convolution of two one-dimensional
535+
sequences.
538536
539537
540538
Examples
@@ -546,7 +544,7 @@ def correlate(a, v, mode="valid"):
546544
array([3.5], dtype=float32)
547545
>>> np.correlate(a, v, "same")
548546
array([2. , 3.5, 3. ], dtype=float32)
549-
>>> np.correlate([1, 2, 3], [0, 1, 0.5], "full")
547+
>>> np.correlate([a, v, "full")
550548
array([0.5, 2. , 3.5, 3. , 0. ], dtype=float32)
551549
552550
Using complex sequences:
@@ -557,10 +555,10 @@ def correlate(a, v, mode="valid"):
557555
array([0.5-0.5j, 1. +0.j , 1.5-1.5j, 3. -1.j , 0. +0.j ], dtype=complex64)
558556
559557
Note that you get the time reversed, complex conjugated result
560-
(:math:`\overline{c_{-k}}`) when the two input sequences a and v change
558+
(:math:`\overline{c_{-k}}`) when the two input sequences `a` and `v` change
561559
places:
562560
563-
>>> np.correlate([0, 1, 0.5j], [1+1j, 2, 3-1j], 'full')
561+
>>> np.correlate(vc, ac, 'full')
564562
array([0. +0.j , 3. +1.j , 1.5+1.5j, 1. +0.j , 0.5+0.5j], dtype=complex64)
565563
566564
"""
@@ -586,10 +584,11 @@ def correlate(a, v, mode="valid"):
586584

587585
if supported_dtype is None:
588586
raise ValueError(
589-
f"function '{correlate}' does not support input types "
590-
f"({a.dtype}, {v.dtype}), "
587+
f"function does not support input types "
588+
f"({a.dtype.name}, {v.dtype.name}), "
591589
"and the inputs could not be coerced to any "
592-
f"supported types. List of supported types: {supported_types}"
590+
f"supported types. List of supported types: "
591+
f"{[st.name for st in supported_types]}"
593592
)
594593

595594
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):

0 commit comments

Comments
 (0)