Skip to content

Commit b5d5c87

Browse files
committed
Cleaning up interop.py
1 parent ced0322 commit b5d5c87

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

arrayfire/interop.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,27 @@ def np_to_af_array(np_arr):
3737
---------
3838
af_arr : arrayfire.Array()
3939
"""
40+
41+
in_shape = np_arr.shape
42+
in_ptr = np_arr.ctypes.data
43+
in_dtype = np_arr.dtype.char
44+
4045
if (np_arr.flags['F_CONTIGUOUS']):
41-
return Array(np_arr.ctypes.data, np_arr.shape, np_arr.dtype.char)
46+
return Array(in_ptr, in_shape, in_dtype)
4247
elif (np_arr.flags['C_CONTIGUOUS']):
4348
if np_arr.ndim == 1:
44-
return Array(np_arr.ctypes.data, np_arr.shape, np_arr.dtype.char)
49+
return Array(in_ptr, in_shape, in_dtype)
4550
elif np_arr.ndim == 2:
46-
shape = (np_arr.shape[1], np_arr.shape[0])
47-
res = Array(np_arr.ctypes.data, shape, np_arr.dtype.char)
51+
shape = (in_shape[1], in_shape[0])
52+
res = Array(in_ptr, shape, in_dtype)
4853
return reorder(res, 1, 0)
4954
elif np_arr.ndim == 3:
50-
shape = (np_arr.shape[2], np_arr.shape[1], np_arr.shape[0])
51-
res = Array(np_arr.ctypes.data, shape, np_arr.dtype.char)
55+
shape = (in_shape[2], in_shape[1], in_shape[0])
56+
res = Array(in_ptr, shape, in_dtype)
5257
return reorder(res, 2, 1, 0)
5358
elif np_arr.ndim == 4:
54-
shape = (np_arr.shape[3], np_arr.shape[2], np_arr.shape[1], np_arr.shape[0])
55-
res = Array(np_arr.ctypes.data, shape, np_arr.dtype.char)
59+
shape = (in_shape[3], in_shape[2], in_shape[1], in_shape[0])
60+
res = Array(in_ptr, shape, in_dtype)
5661
return reorder(res, 3, 2, 1, 0)
5762
else:
5863
raise RuntimeError("Unsupported ndim")
@@ -79,26 +84,31 @@ def pycuda_to_af_array(pycu_arr):
7984
----------
8085
af_arr : arrayfire.Array()
8186
"""
87+
88+
in_ptr = pycu_arr.ptr
89+
in_shape = pycu_arr.shape
90+
in_dtype = pycu_arr.dtype.char
91+
8292
if (pycu_arr.flags.f_contiguous):
83-
res = Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True)
93+
res = Array(in_ptr, in_shape, in_dtype, is_device=True)
8494
lock_array(res)
8595
return res
8696
elif (pycu_arr.flags.c_contiguous):
8797
if pycu_arr.ndim == 1:
88-
return Array(pycu_arr.ptr, pycu_arr.shape, pycu_arr.dtype.char, is_device=True)
98+
return Array(in_ptr, in_shape, in_dtype, is_device=True)
8999
elif pycu_arr.ndim == 2:
90-
shape = (pycu_arr.shape[1], pycu_arr.shape[0])
91-
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
100+
shape = (in_shape[1], in_shape[0])
101+
res = Array(in_ptr, shape, in_dtype, is_device=True)
92102
lock_array(res)
93103
return reorder(res, 1, 0)
94104
elif pycu_arr.ndim == 3:
95-
shape = (pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0])
96-
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
105+
shape = (in_shape[2], in_shape[1], in_shape[0])
106+
res = Array(in_ptr, shape, in_dtype, is_device=True)
97107
lock_array(res)
98108
return reorder(res, 2, 1, 0)
99109
elif pycu_arr.ndim == 4:
100-
shape = (pycu_arr.shape[3], pycu_arr.shape[2], pycu_arr.shape[1], pycu_arr.shape[0])
101-
res = Array(pycu_arr.ptr, shape, pycu_arr.dtype.char, is_device=True)
110+
shape = (in_shape[3], in_shape[2], in_shape[1], in_shape[0])
111+
res = Array(in_ptr, shape, in_dtype, is_device=True)
102112
lock_array(res)
103113
return reorder(res, 3, 2, 1, 0)
104114
else:

0 commit comments

Comments
 (0)