@@ -150,7 +150,7 @@ def matmulTT(lhs, rhs):
150
150
MATPROP .TRANS .value , MATPROP .TRANS .value ))
151
151
return out
152
152
153
- def dot (lhs , rhs , lhs_opts = MATPROP .NONE , rhs_opts = MATPROP .NONE ):
153
+ def dot (lhs , rhs , lhs_opts = MATPROP .NONE , rhs_opts = MATPROP .NONE , return_scalar = False ):
154
154
"""
155
155
Dot product of two input vectors.
156
156
@@ -173,10 +173,13 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
173
173
- af.MATPROP.NONE - If no op should be done on `rhs`.
174
174
- No other options are currently supported.
175
175
176
+ return_scalar: optional: bool. default: False.
177
+ - When set to true, the input arrays are flattened and the output is a scalar
178
+
176
179
Returns
177
180
-------
178
181
179
- out : af.Array
182
+ out : af.Array or scalar
180
183
Output of dot product of `lhs` and `rhs`.
181
184
182
185
Note
@@ -186,7 +189,16 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
186
189
- Batches are not supported.
187
190
188
191
"""
189
- out = Array ()
190
- safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
191
- lhs_opts .value , rhs_opts .value ))
192
- return out
192
+ if return_scalar :
193
+ real = c_double_t (0 )
194
+ imag = c_double_t (0 )
195
+ safe_call (backend .get ().af_dot_all (c_pointer (real ), c_pointer (imag ),
196
+ lhs .arr , rhs .arr , lhs_opts .value , rhs_opts .value ))
197
+ real = real .value
198
+ imag = imag .value
199
+ return real if imag == 0 else real + imag * 1j
200
+ else :
201
+ out = Array ()
202
+ safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
203
+ lhs_opts .value , rhs_opts .value ))
204
+ return out
0 commit comments