Skip to content

Commit d522382

Browse files
committed
Fix product functions output Array type
For boolean/char inputs to ArrayFire, the output of product operation is char and everywhere else same as AggregateType alias.
1 parent 33644ff commit d522382

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/algorithm/mod.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ dim_reduce_func_def!(
143143
",
144144
product,
145145
af_product,
146-
T::AggregateOutType
146+
T::ProductOutType
147147
);
148148

149149
dim_reduce_func_def!(
@@ -440,10 +440,10 @@ where
440440
/// # Return Values
441441
///
442442
/// Array that is reduced along given dimension via multiplication operation
443-
pub fn product_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::AggregateOutType>
443+
pub fn product_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::ProductOutType>
444444
where
445445
T: HasAfEnum,
446-
T::AggregateOutType: HasAfEnum,
446+
T::ProductOutType: HasAfEnum,
447447
{
448448
let mut temp: i64 = 0;
449449
unsafe {

src/util.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,15 @@ pub trait HasAfEnum {
182182
/// aggregation of set of values for a given input type. Aggregate type
183183
/// alias points to below types for given input types:
184184
/// - `Self` for input types: `Complex<64>`, `Complex<f32>`, `f64`, `f32`, `i64`, `u64`
185-
/// - `f32` for input types: `bool`
185+
/// - `u32` for input types: `bool`
186186
/// - `u32` for input types: `u8`
187187
/// - `i32` for input types: `i16`
188188
/// - `u32` for input types: `u16`
189189
/// - `i32` for input types: `i32`
190190
/// - `u32` for input types: `u32`
191191
type AggregateOutType;
192+
/// This type is different for b8 input type
193+
type ProductOutType;
192194
/// This type alias points to the output type for given input type of
193195
/// sobel filter operation. Sobel filter output alias points to below
194196
/// types for given input types:
@@ -211,6 +213,7 @@ impl HasAfEnum for Complex<f32> {
211213
type ComplexOutType = Self;
212214
type MeanOutType = Self;
213215
type AggregateOutType = Self;
216+
type ProductOutType = Self;
214217
type SobelOutType = Self;
215218

216219
fn get_af_dtype() -> DType {
@@ -226,6 +229,7 @@ impl HasAfEnum for Complex<f64> {
226229
type ComplexOutType = Self;
227230
type MeanOutType = Self;
228231
type AggregateOutType = Self;
232+
type ProductOutType = Self;
229233
type SobelOutType = Self;
230234

231235
fn get_af_dtype() -> DType {
@@ -241,6 +245,7 @@ impl HasAfEnum for f32 {
241245
type ComplexOutType = Complex<f32>;
242246
type MeanOutType = Self;
243247
type AggregateOutType = Self;
248+
type ProductOutType = Self;
244249
type SobelOutType = Self;
245250

246251
fn get_af_dtype() -> DType {
@@ -256,6 +261,7 @@ impl HasAfEnum for f64 {
256261
type ComplexOutType = Complex<f64>;
257262
type MeanOutType = Self;
258263
type AggregateOutType = Self;
264+
type ProductOutType = Self;
259265
type SobelOutType = Self;
260266

261267
fn get_af_dtype() -> DType {
@@ -271,6 +277,7 @@ impl HasAfEnum for bool {
271277
type ComplexOutType = Complex<f32>;
272278
type MeanOutType = f32;
273279
type AggregateOutType = u32;
280+
type ProductOutType = bool;
274281
type SobelOutType = i32;
275282

276283
fn get_af_dtype() -> DType {
@@ -286,6 +293,7 @@ impl HasAfEnum for u8 {
286293
type ComplexOutType = Complex<f32>;
287294
type MeanOutType = f32;
288295
type AggregateOutType = u32;
296+
type ProductOutType = u32;
289297
type SobelOutType = i32;
290298

291299
fn get_af_dtype() -> DType {
@@ -301,6 +309,7 @@ impl HasAfEnum for i16 {
301309
type ComplexOutType = Complex<f32>;
302310
type MeanOutType = f32;
303311
type AggregateOutType = i32;
312+
type ProductOutType = i32;
304313
type SobelOutType = i32;
305314

306315
fn get_af_dtype() -> DType {
@@ -316,6 +325,7 @@ impl HasAfEnum for u16 {
316325
type ComplexOutType = Complex<f32>;
317326
type MeanOutType = f32;
318327
type AggregateOutType = u32;
328+
type ProductOutType = u32;
319329
type SobelOutType = i32;
320330

321331
fn get_af_dtype() -> DType {
@@ -331,6 +341,7 @@ impl HasAfEnum for i32 {
331341
type ComplexOutType = Complex<f32>;
332342
type MeanOutType = f32;
333343
type AggregateOutType = i32;
344+
type ProductOutType = i32;
334345
type SobelOutType = i32;
335346

336347
fn get_af_dtype() -> DType {
@@ -346,6 +357,7 @@ impl HasAfEnum for u32 {
346357
type ComplexOutType = Complex<f32>;
347358
type MeanOutType = f32;
348359
type AggregateOutType = u32;
360+
type ProductOutType = u32;
349361
type SobelOutType = i32;
350362

351363
fn get_af_dtype() -> DType {
@@ -361,6 +373,7 @@ impl HasAfEnum for i64 {
361373
type ComplexOutType = Complex<f64>;
362374
type MeanOutType = f64;
363375
type AggregateOutType = Self;
376+
type ProductOutType = Self;
364377
type SobelOutType = i64;
365378

366379
fn get_af_dtype() -> DType {
@@ -376,6 +389,7 @@ impl HasAfEnum for u64 {
376389
type ComplexOutType = Complex<f64>;
377390
type MeanOutType = f64;
378391
type AggregateOutType = Self;
392+
type ProductOutType = Self;
379393
type SobelOutType = i64;
380394

381395
fn get_af_dtype() -> DType {

0 commit comments

Comments
 (0)