Skip to content

Commit c16e98a

Browse files
authored
manip module to desc (#777)
1 parent 795e934 commit c16e98a

File tree

3 files changed

+100
-123
lines changed

3 files changed

+100
-123
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ cpdef dparray dpnp_subtract(object x1_obj, object x2_obj, object dtype=*, dparra
302302
"""
303303
Array manipulation routines
304304
"""
305-
cpdef dparray dpnp_repeat(dparray array1, repeats, axes=*)
306-
cpdef dparray dpnp_transpose(dparray array1, axes=*)
305+
cpdef dparray dpnp_repeat(dpnp_descriptor array1, repeats, axes=*)
306+
cpdef dparray dpnp_transpose(dpnp_descriptor array1, axes=*)
307307

308308

309309
"""

dpnp/dpnp_algo/dpnp_algo_manipulation.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ cpdef dparray dpnp_atleast_3d(dparray arr):
8282
return arr
8383

8484

85-
cpdef dpnp_copyto(dparray dst, dparray src, where=True):
85+
cpdef dpnp_copyto(utils.dpnp_descriptor dst, utils.dpnp_descriptor src, where=True):
8686
# Convert string type names (dparray.dtype) to C enum DPNPFuncType
8787
cdef DPNPFuncType dst_type = dpnp_dtype_to_DPNPFuncType(dst.dtype)
8888
cdef DPNPFuncType src_type = dpnp_dtype_to_DPNPFuncType(src.dtype)
@@ -127,7 +127,7 @@ cpdef dparray dpnp_expand_dims(dparray in_array, axis):
127127
return result
128128

129129

130-
cpdef dparray dpnp_repeat(dparray array1, repeats, axes=None):
130+
cpdef dparray dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
131131
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
132132

133133
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_REPEAT, param1_type, param1_type)
@@ -142,7 +142,7 @@ cpdef dparray dpnp_repeat(dparray array1, repeats, axes=None):
142142
return result
143143

144144

145-
cpdef dparray dpnp_transpose(dparray array1, axes=None):
145+
cpdef dparray dpnp_transpose(utils.dpnp_descriptor array1, axes=None):
146146
cdef dparray_shape_type input_shape = array1.shape
147147
cdef size_t input_shape_size = array1.ndim
148148
cdef dparray_shape_type result_shape = dparray_shape_type(input_shape_size, 1)
@@ -186,7 +186,7 @@ cpdef dparray dpnp_transpose(dparray array1, axes=None):
186186
return result
187187

188188

189-
cpdef dparray dpnp_squeeze(dparray in_array, axis):
189+
cpdef dparray dpnp_squeeze(object in_array, axis):
190190
shape_list = []
191191
if axis is None:
192192
for i in range(in_array.ndim):

dpnp/dpnp_iface_manipulation.py

Lines changed: 94 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@
4141

4242

4343
import collections.abc
44-
import numpy
4544

4645
from dpnp.dpnp_algo import *
4746
from dpnp.dparray import dparray
4847
from dpnp.dpnp_utils import *
49-
import dpnp
5048
from dpnp.dpnp_iface_arraycreation import array
5149

50+
import dpnp
51+
import numpy
52+
5253

5354
__all__ = [
5455
"asfarray",
@@ -67,7 +68,7 @@
6768
]
6869

6970

70-
def asfarray(a, dtype=numpy.float64):
71+
def asfarray(x1, dtype=numpy.float64):
7172
"""
7273
Return an array converted to a float type.
7374
@@ -79,18 +80,19 @@ def asfarray(a, dtype=numpy.float64):
7980
8081
"""
8182

82-
if not use_origin_backend(a):
83+
x1_desc = dpnp.get_dpnp_descriptor(x1)
84+
if x1_desc:
8385
# behavior of original function: int types replaced with float64
8486
if numpy.issubdtype(dtype, numpy.integer):
8587
dtype = numpy.float64
8688

8789
# if type is the same then same object should be returned
88-
if isinstance(a, dpnp.ndarray) and a.dtype == dtype:
89-
return a
90+
if x1_desc.dtype == dtype:
91+
return x1
9092

91-
return array(a, dtype=dtype)
93+
return array(x1, dtype=dtype)
9294

93-
return call_origin(numpy.asfarray, a, dtype)
95+
return call_origin(numpy.asfarray, x1, dtype)
9496

9597

9698
def atleast_1d(*arys):
@@ -188,32 +190,31 @@ def copyto(dst, src, casting='same_kind', where=True):
188190
Input array data types are limited by supported DPNP :ref:`Data types`.
189191
190192
"""
191-
if not use_origin_backend(dst):
192-
if not isinstance(dst, dparray):
193-
pass
194-
elif not isinstance(src, dparray):
195-
pass
196-
elif casting != 'same_kind':
193+
194+
dst_desc = dpnp.get_dpnp_descriptor(dst)
195+
src_desc = dpnp.get_dpnp_descriptor(src)
196+
if dst_desc and src_desc:
197+
if casting != 'same_kind':
197198
pass
198-
elif (dst.dtype == dpnp.bool and # due to 'same_kind' casting
199-
src.dtype in [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64, dpnp.complex128]):
199+
elif (dst_desc.dtype == dpnp.bool and # due to 'same_kind' casting
200+
src_desc.dtype in [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64, dpnp.complex128]):
200201
pass
201-
elif (dst.dtype in [dpnp.int32, dpnp.int64] and # due to 'same_kind' casting
202-
src.dtype in [dpnp.float32, dpnp.float64, dpnp.complex128]):
202+
elif (dst_desc.dtype in [dpnp.int32, dpnp.int64] and # due to 'same_kind' casting
203+
src_desc.dtype in [dpnp.float32, dpnp.float64, dpnp.complex128]):
203204
pass
204-
elif dst.dtype in [dpnp.float32, dpnp.float64] and src.dtype == dpnp.complex128: # due to 'same_kind' casting
205+
elif dst_desc.dtype in [dpnp.float32, dpnp.float64] and src_desc.dtype == dpnp.complex128: # due to 'same_kind' casting
205206
pass
206207
elif where is not True:
207208
pass
208-
elif dst.shape != src.shape:
209+
elif dst_desc.shape != src_desc.shape:
209210
pass
210211
else:
211-
return dpnp_copyto(dst, src, where=where)
212+
return dpnp_copyto(dst_desc, src_desc, where=where)
212213

213214
return call_origin(numpy.copyto, dst, src, casting, where)
214215

215216

216-
def expand_dims(a, axis):
217+
def expand_dims(x1, axis):
217218
"""
218219
Expand the shape of an array.
219220
@@ -271,13 +272,11 @@ def expand_dims(a, axis):
271272
272273
"""
273274

274-
if not use_origin_backend(a):
275-
if not isinstance(a, dpnp.ndarray):
276-
pass
277-
else:
278-
return dpnp_expand_dims(a, axis)
275+
x1_desc = dpnp.get_dpnp_descriptor(x1)
276+
if x1_desc:
277+
return dpnp_expand_dims(x1, axis)
279278

280-
return call_origin(numpy.expand_dims, a, axis)
279+
return call_origin(numpy.expand_dims, x1, axis)
281280

282281

283282
def moveaxis(x1, source, destination):
@@ -309,45 +308,33 @@ def moveaxis(x1, source, destination):
309308
310309
"""
311310

312-
if (use_origin_backend(x1)):
313-
return numpy.swapaxes(x1, source, destination)
314-
315-
if (not isinstance(x1, dparray)):
316-
return numpy.swapaxes(x1, source, destination)
317-
318-
if not isinstance(source, collections.abc.Sequence): # assume scalar
319-
source = (source,)
320-
321-
if not isinstance(destination, collections.abc.Sequence): # assume scalar
322-
destination = (destination,)
323-
324-
source_norm = normalize_axis(source, x1.ndim)
325-
destination_norm = normalize_axis(destination, x1.ndim)
311+
x1_desc = dpnp.get_dpnp_descriptor(x1)
312+
if x1_desc:
313+
source_norm = normalize_axis(source, x1_desc.ndim)
314+
destination_norm = normalize_axis(destination, x1_desc.ndim)
326315

327-
if len(source_norm) != len(destination_norm):
328-
checker_throw_axis_error(
329-
"swapaxes",
330-
"source_norm.size() != destination_norm.size()",
331-
source_norm,
332-
destination_norm)
316+
if len(source_norm) != len(destination_norm):
317+
pass
318+
else:
319+
# 'do nothing' pattern for transpose() with no elements in 'source'
320+
input_permute = []
321+
for i in range(x1_desc.ndim):
322+
if i not in source_norm:
323+
input_permute.append(i)
333324

334-
# 'do nothing' pattern for transpose() with no elements in 'source'
335-
input_permute = []
336-
for i in range(x1.ndim):
337-
if i not in source_norm:
338-
input_permute.append(i)
325+
# insert moving axes into proper positions
326+
for destination_id, source_id in sorted(zip(destination_norm, source_norm)):
327+
# if destination_id in input_permute:
328+
# pytest tests/third_party/cupy/manipulation_tests/test_transpose.py::TestTranspose::test_moveaxis_invalid5_3
329+
# checker_throw_value_error("swapaxes", "source_id exists", source_id, input_permute)
330+
input_permute.insert(destination_id, source_id)
339331

340-
# insert moving axes into proper positions
341-
for destination_id, source_id in sorted(zip(destination_norm, source_norm)):
342-
# if destination_id in input_permute:
343-
# pytest tests/third_party/cupy/manipulation_tests/test_transpose.py::TestTranspose::test_moveaxis_invalid5_3
344-
# checker_throw_value_error("swapaxes", "source_id exists", source_id, input_permute)
345-
input_permute.insert(destination_id, source_id)
332+
return transpose(x1_desc, axes=input_permute)
346333

347-
return transpose(x1, axes=input_permute)
334+
return call_origin(numpy.moveaxis, x1, source, destination)
348335

349336

350-
def ravel(a, order='C'):
337+
def ravel(x1, order='C'):
351338
"""
352339
Return a contiguous flattened array.
353340
@@ -369,12 +356,11 @@ def ravel(a, order='C'):
369356
370357
"""
371358

372-
if not use_origin_backend(a) and isinstance(a, dparray):
373-
return a.ravel(order=order)
359+
x1_desc = dpnp.get_dpnp_descriptor(x1)
360+
if x1_desc:
361+
return dpnp_flatten(x1)
374362

375-
result = numpy.rollaxis(dp2nd_array(a), order=order)
376-
377-
return nd2dp_array(result)
363+
return call_origin(numpy.ravel, x1, order=order)
378364

379365

380366
def repeat(x1, repeats, axis=None):
@@ -403,23 +389,22 @@ def repeat(x1, repeats, axis=None):
403389
404390
"""
405391

406-
if not use_origin_backend(x1):
407-
if not isinstance(x1, dparray):
408-
pass
409-
elif axis is not None and axis != 0:
392+
x1_desc = dpnp.get_dpnp_descriptor(x1)
393+
if x1_desc:
394+
if axis is not None and axis != 0:
410395
pass
411-
elif x1.ndim >= 2:
396+
elif x1_desc.ndim >= 2:
412397
pass
413398
elif not dpnp.isscalar(repeats) and len(repeats) > 1:
414399
pass
415400
else:
416401
repeat_val = repeats if dpnp.isscalar(repeats) else repeats[0]
417-
return dpnp_repeat(x1, repeat_val, axis)
402+
return dpnp_repeat(x1_desc, repeat_val, axis)
418403

419404
return call_origin(numpy.repeat, x1, repeats, axis)
420405

421406

422-
def rollaxis(a, axis, start=0):
407+
def rollaxis(x1, axis, start=0):
423408
"""
424409
Roll the specified axis backwards, until it lies in a given position.
425410
@@ -452,25 +437,22 @@ def rollaxis(a, axis, start=0):
452437
453438
"""
454439

455-
if not use_origin_backend(a):
456-
if not isinstance(a, dparray):
440+
x1_desc = dpnp.get_dpnp_descriptor(x1)
441+
if x1_desc:
442+
if not isinstance(axis, int):
457443
pass
458-
elif not isinstance(axis, int):
459-
pass
460-
elif start < -a.ndim or start > a.ndim:
444+
elif start < -x1_desc.ndim or start > x1_desc.ndim:
461445
pass
462446
else:
463-
start_norm = start + a.ndim if start < 0 else start
447+
start_norm = start + x1_desc.ndim if start < 0 else start
464448
destination = start_norm - 1 if start_norm > axis else start_norm
465449

466-
return dpnp.moveaxis(a, axis, destination)
467-
468-
result = numpy.rollaxis(dp2nd_array(a), axis, start)
450+
return dpnp.moveaxis(x1_desc, axis, destination)
469451

470-
return nd2dp_array(result)
452+
return call_origin(numpy.rollaxis, x1, axis, start)
471453

472454

473-
def squeeze(a, axis=None):
455+
def squeeze(x1, axis=None):
474456
"""
475457
Remove single-dimensional entries from the shape of an array.
476458
@@ -504,13 +486,11 @@ def squeeze(a, axis=None):
504486
505487
"""
506488

507-
if not use_origin_backend(a):
508-
if not isinstance(a, dpnp.ndarray):
509-
pass
510-
else:
511-
return dpnp_squeeze(a, axis)
489+
x1_desc = dpnp.get_dpnp_descriptor(x1)
490+
if x1_desc:
491+
return dpnp_squeeze(x1, axis)
512492

513-
return call_origin(numpy.squeeze, a, axis)
493+
return call_origin(numpy.squeeze, x1, axis)
514494

515495

516496
def swapaxes(x1, axis1, axis2):
@@ -539,24 +519,21 @@ def swapaxes(x1, axis1, axis2):
539519
540520
"""
541521

542-
if (use_origin_backend(x1)):
543-
return numpy.swapaxes(x1, axis1, axis2)
544-
545-
if (not isinstance(x1, dparray)):
546-
return numpy.swapaxes(x1, axis1, axis2)
547-
548-
if not (axis1 < x1.ndim):
549-
checker_throw_value_error("swapaxes", "axis1", axis1, x1.ndim - 1)
550-
551-
if not (axis2 < x1.ndim):
552-
checker_throw_value_error("swapaxes", "axis2", axis2, x1.ndim - 1)
522+
x1_desc = dpnp.get_dpnp_descriptor(x1)
523+
if x1_desc:
524+
if axis1 >= x1_desc.ndim:
525+
pass
526+
elif axis2 >= x1_desc.ndim:
527+
pass
528+
else:
529+
# 'do nothing' pattern for transpose()
530+
input_permute = [i for i in range(x1.ndim)]
531+
# swap axes
532+
input_permute[axis1], input_permute[axis2] = input_permute[axis2], input_permute[axis1]
553533

554-
# 'do nothing' pattern for transpose()
555-
input_permute = [i for i in range(x1.ndim)]
556-
# swap axes
557-
input_permute[axis1], input_permute[axis2] = input_permute[axis2], input_permute[axis1]
534+
return transpose(x1_desc, axes=input_permute)
558535

559-
return transpose(x1, axes=input_permute)
536+
return call_origin(numpy.swapaxes, x1, axis1, axis2)
560537

561538

562539
def transpose(x1, axes=None):
@@ -593,17 +570,17 @@ def transpose(x1, axes=None):
593570
594571
"""
595572

596-
if (use_origin_backend(x1)):
597-
return numpy.transpose(x1, axes=axes)
573+
x1_desc = dpnp.get_dpnp_descriptor(x1)
574+
if x1_desc:
575+
if axes is not None:
576+
if not any(axes):
577+
"""
578+
pytest tests/third_party/cupy/manipulation_tests/test_transpose.py
579+
"""
580+
axes = None
598581

599-
if (not isinstance(x1, dparray)):
600-
return numpy.transpose(x1, axes=axes)
582+
result = dpnp_transpose(x1_desc, axes)
601583

602-
if (axes is not None):
603-
if (not any(axes)):
604-
"""
605-
pytest tests/third_party/cupy/manipulation_tests/test_transpose.py
606-
"""
607-
axes = None
584+
return result
608585

609-
return dpnp_transpose(x1, axes=axes)
586+
return call_origin(numpy.transpose, x1, axes=axes)

0 commit comments

Comments
 (0)