Skip to content

Commit 7bf28b7

Browse files
committed
FEAT: Adding support for dot
1 parent 8bf7c01 commit 7bf28b7

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

arrayfire/blas.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def matmulTT(lhs, rhs):
150150
MATPROP.TRANS.value, MATPROP.TRANS.value))
151151
return out
152152

153-
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
153+
def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar = False):
154154
"""
155155
Dot product of two input vectors.
156156
@@ -173,10 +173,13 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
173173
- af.MATPROP.NONE - If no op should be done on `rhs`.
174174
- No other options are currently supported.
175175
176+
return_scalar: optional: bool. default: False.
177+
- When set to true, the input arrays are flattened and the output is a scalar
178+
176179
Returns
177180
-------
178181
179-
out : af.Array
182+
out : af.Array or scalar
180183
Output of dot product of `lhs` and `rhs`.
181184
182185
Note
@@ -186,7 +189,16 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
186189
- Batches are not supported.
187190
188191
"""
189-
out = Array()
190-
safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
191-
lhs_opts.value, rhs_opts.value))
192-
return out
192+
if return_scalar:
193+
real = c_double_t(0)
194+
imag = c_double_t(0)
195+
safe_call(backend.get().af_dot_all(c_pointer(real), c_pointer(imag),
196+
lhs.arr, rhs.arr, lhs_opts.value, rhs_opts.value))
197+
real = real.value
198+
imag = imag.value
199+
return real if imag == 0 else real + imag * 1j
200+
else:
201+
out = Array()
202+
safe_call(backend.get().af_dot(c_pointer(out.arr), lhs.arr, rhs.arr,
203+
lhs_opts.value, rhs_opts.value))
204+
return out

0 commit comments

Comments
 (0)