Skip to content

Commit 78a0353

Browse files
authored
Merge pull request #94 from 9prady9/constant_like_and_Negtrait
New data generation functions and Neg trait implementation
2 parents babec9d + 10a20d1 commit 78a0353

File tree

4 files changed

+201
-6
lines changed

4 files changed

+201
-6
lines changed

src/arith/mod.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ extern crate num;
33

44
use dim4::Dim4;
55
use array::Array;
6-
use defines::AfError;
6+
use defines::{AfError, DType, Scalar};
77
use error::HANDLE_ERROR;
88
use self::libc::{c_int};
9-
use data::{constant, tile};
9+
use data::{constant, constant_t, tile};
1010
use self::num::Complex;
1111

12+
use std::ops::Neg;
13+
14+
type Complex32 = Complex<f32>;
15+
type Complex64 = Complex<f64>;
1216
type MutAfArray = *mut self::libc::c_longlong;
1317
type MutDouble = *mut self::libc::c_double;
1418
type MutUint = *mut self::libc::c_uint;
@@ -485,3 +489,25 @@ bit_assign_func!(BitOrAssign, bitor_assign, bitor);
485489
bit_assign_func!(BitXorAssign, bitxor_assign, bitxor);
486490

487491
}
492+
493+
///Implement negation trait for Array
494+
impl Neg for Array {
495+
type Output = Array;
496+
497+
fn neg(self) -> Self::Output {
498+
match self.get_type() {
499+
DType::S64 => (constant_t(Scalar::S64(0 as i64), self.dims(), DType::S64) - self),
500+
DType::U64 => (constant_t(Scalar::U64(0 as u64), self.dims(), DType::U64) - self),
501+
DType::C32 => (constant_t(Scalar::C32(Complex32::new(0.0, 0.0)), self.dims(), DType::C32) - self),
502+
DType::C64 => (constant_t(Scalar::C64(Complex64::new(0.0, 0.0)), self.dims(), DType::C64) - self),
503+
DType::F32 => (constant_t(Scalar::F32(0 as f32), self.dims(), DType::F32) - self),
504+
DType::F64 => (constant_t(Scalar::F64(0 as f64), self.dims(), DType::F64) - self),
505+
DType::B8 => (constant_t(Scalar::B8 (false ), self.dims(), DType::B8 ) - self),
506+
DType::S32 => (constant_t(Scalar::S32(0 as i32), self.dims(), DType::S32) - self),
507+
DType::U32 => (constant_t(Scalar::U32(0 as u32), self.dims(), DType::U32) - self),
508+
DType::U8 => (constant_t(Scalar::U8 (0 as u8 ), self.dims(), DType::U8 ) - self),
509+
DType::S16 => (constant_t(Scalar::S16(0 as i16), self.dims(), DType::S16) - self),
510+
DType::U16 => (constant_t(Scalar::U16(0 as u16), self.dims(), DType::U16) - self),
511+
}
512+
}
513+
}

src/data/mod.rs

+137-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate num;
33

44
use array::Array;
55
use dim4::Dim4;
6-
use defines::AfError;
6+
use defines::{AfError, DType, Scalar};
77
use error::HANDLE_ERROR;
88
use self::libc::{uint8_t, c_int, c_uint, c_double};
99
use self::num::Complex;
@@ -622,3 +622,139 @@ pub fn replace_scalar(a: &mut Array, cond: &Array, b: f64) {
622622
HANDLE_ERROR(AfError::from(err_val));
623623
}
624624
}
625+
626+
/// Create a range of values of given type([DType](./enum.DType.html))
627+
///
628+
/// Creates an array with [0, n] values along the `seq_dim` which is tiled across other dimensions.
629+
///
630+
/// # Parameters
631+
///
632+
/// - `dims` is the size of Array
633+
/// - `seq_dim` is the dimension along which range values are populated, all values along other
634+
/// dimensions are just repeated
635+
/// - `dtype` indicates whats the type of the Array to be created
636+
///
637+
/// # Return Values
638+
/// Array
639+
#[allow(unused_mut)]
640+
pub fn range_t(dims: Dim4, seq_dim: i32, dtype: DType) -> Array {
641+
unsafe {
642+
let mut temp: i64 = 0;
643+
let err_val = af_range(&mut temp as MutAfArray,
644+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
645+
seq_dim as c_int, dtype as uint8_t);
646+
HANDLE_ERROR(AfError::from(err_val));
647+
Array::from(temp)
648+
}
649+
}
650+
651+
/// Create a range of values of given type([DType](./enum.DType.html))
652+
///
653+
/// Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
654+
///
655+
/// # Parameters
656+
///
657+
/// - `dims` is the dimensions of the sequence to be generated
658+
/// - `tdims` is the number of repitions of the unit dimensions
659+
/// - `dtype` indicates whats the type of the Array to be created
660+
///
661+
/// # Return Values
662+
///
663+
/// Array
664+
#[allow(unused_mut)]
665+
pub fn iota_t(dims: Dim4, tdims: Dim4, dtype: DType) -> Array {
666+
unsafe {
667+
let mut temp: i64 = 0;
668+
let err_val =af_iota(&mut temp as MutAfArray,
669+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
670+
tdims.ndims() as c_uint, tdims.get().as_ptr() as *const DimT,
671+
dtype as uint8_t);
672+
HANDLE_ERROR(AfError::from(err_val));
673+
Array::from(temp)
674+
}
675+
}
676+
677+
/// Create an identity array with 1's in diagonal of given type([DType](./enum.DType.html))
678+
///
679+
/// # Parameters
680+
///
681+
/// - `dims` is the output Array dimensions
682+
/// - `dtype` indicates whats the type of the Array to be created
683+
///
684+
/// # Return Values
685+
///
686+
/// Identity matrix
687+
#[allow(unused_mut)]
688+
pub fn identity_t(dims: Dim4, dtype: DType) -> Array {
689+
unsafe {
690+
let mut temp: i64 = 0;
691+
let err_val = af_identity(&mut temp as MutAfArray,
692+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
693+
dtype as uint8_t);
694+
HANDLE_ERROR(AfError::from(err_val));
695+
Array::from(temp)
696+
}
697+
}
698+
699+
/// Create a constant array of given type([DType](./enum.DType.html))
700+
///
701+
/// You can use this function to create arrays of type dictated by the enum
702+
/// [DType](./enum.DType.html) using the scalar `value` that has the shape similar
703+
/// to `dims`.
704+
///
705+
/// # Parameters
706+
///
707+
/// - `value` is the [Scalar](./enum.Scalar.html) to be filled into the array
708+
/// - `dims` is the output Array dimensions
709+
/// - `dtype` indicates the type of Array to be created and is the type of the scalar to be passed
710+
/// via the paramter `value`.
711+
///
712+
/// # Return Values
713+
///
714+
/// Array of `dims` shape and filed with given constant `value`.
715+
#[allow(unused_mut)]
716+
pub fn constant_t(value: Scalar, dims: Dim4, dtype: DType) -> Array {
717+
use Scalar::*;
718+
719+
// Below macro is only visible to this function
720+
// and it is used to abbreviate the repetitive const calls
721+
macro_rules! expand_const_call {
722+
($ffi_name: ident, $temp: expr, $v: expr, $dims: expr, $dt: expr) => ({
723+
$ffi_name(&mut $temp as MutAfArray, $v as c_double,
724+
$dims.ndims() as c_uint, $dims.get().as_ptr() as *const DimT, $dt)
725+
})
726+
}
727+
728+
unsafe {
729+
let dt = dtype as c_int;
730+
let mut temp: i64 = 0;
731+
let err_val = match value {
732+
C32(v) => {
733+
af_constant_complex(&mut temp as MutAfArray, v.re as c_double, v.im as c_double,
734+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT, dt)
735+
},
736+
C64(v) => {
737+
af_constant_complex(&mut temp as MutAfArray, v.re as c_double, v.im as c_double,
738+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT, dt)
739+
},
740+
S64(v) => {
741+
af_constant_long(&mut temp as MutAfArray, v as Intl,
742+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT)
743+
},
744+
U64(v) => {
745+
af_constant_ulong(&mut temp as MutAfArray, v as Uintl,
746+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT)
747+
},
748+
F32(v) => expand_const_call!(af_constant, temp, v, dims, dt),
749+
F64(v) => expand_const_call!(af_constant, temp, v, dims, dt),
750+
B8(v) => expand_const_call!(af_constant, temp, v as i32, dims, dt),
751+
S32(v) => expand_const_call!(af_constant, temp, v, dims, dt),
752+
U32(v) => expand_const_call!(af_constant, temp, v, dims, dt),
753+
U8(v) => expand_const_call!(af_constant, temp, v, dims, dt),
754+
S16(v) => expand_const_call!(af_constant, temp, v, dims, dt),
755+
U16(v) => expand_const_call!(af_constant, temp, v, dims, dt),
756+
};
757+
HANDLE_ERROR(AfError::from(err_val));
758+
Array::from(temp)
759+
}
760+
}

src/defines.rs

+32
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
extern crate num;
2+
13
use std::error::Error;
24
use std::fmt::{Display, Formatter};
35
use std::fmt::Error as FmtError;
6+
use self::num::Complex;
47

58
/// Error codes
69
#[repr(C)]
@@ -397,3 +400,32 @@ pub const PHILOX : RandomEngineType = RandomEngineType::PHILOX_4X32_10;
397400
pub const THREEFRY : RandomEngineType = RandomEngineType::THREEFRY_2X32_16;
398401
pub const MERSENNE : RandomEngineType = RandomEngineType::MERSENNE_GP11213;
399402
pub const DEFAULT_RANDOM_ENGINE : RandomEngineType = PHILOX;
403+
404+
/// Scalar value types
405+
#[derive(Clone, Copy, Debug, PartialEq)]
406+
pub enum Scalar {
407+
/// 32 bit float
408+
F32(f32),
409+
/// 32 bit complex float
410+
C32(Complex<f32>),
411+
/// 64 bit float
412+
F64(f64),
413+
/// 64 bit complex float
414+
C64(Complex<f64>),
415+
/// 8 bit boolean
416+
B8(bool),
417+
/// 32 bit signed integer
418+
S32(i32),
419+
/// 32 bit unsigned integer
420+
U32(u32),
421+
/// 8 bit unsigned integer
422+
U8(u8),
423+
/// 64 bit signed integer
424+
S64(i64),
425+
/// 64 bit unsigned integer
426+
U64(u64),
427+
/// 16 bit signed integer
428+
S16(i16),
429+
/// 16 bit unsigned integer
430+
U16(u16),
431+
}

src/lib.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ mod backend;
3232
pub use blas::{matmul, dot, transpose, transpose_inplace};
3333
mod blas;
3434

35-
pub use data::{constant, range, iota};
36-
pub use data::{identity, diag_create, diag_extract, lower, upper};
35+
pub use data::{constant, range, iota, identity};
36+
pub use data::{diag_create, diag_extract, lower, upper};
3737
pub use data::{join, join_many, tile};
3838
pub use data::{reorder, shift, moddims, flat, flip};
3939
pub use data::{select, selectl, selectr, replace, replace_scalar};
40+
pub use data::{range_t, iota_t, identity_t, constant_t};
4041
mod data;
4142

4243
pub use device::{get_version, info, init, device_count, is_double_available, set_device, get_device};
@@ -47,7 +48,7 @@ pub use defines::{DType, AfError, Backend, ColorMap, YCCStd, HomographyType};
4748
pub use defines::{InterpType, BorderType, MatchType, NormType};
4849
pub use defines::{Connectivity, ConvMode, ConvDomain, ColorSpace, MatProp};
4950
pub use defines::{MarkerType, MomentType, SparseFormat, BinaryOp, RandomEngineType};
50-
pub use defines::{PHILOX, THREEFRY, MERSENNE, DEFAULT_RANDOM_ENGINE};
51+
pub use defines::{PHILOX, THREEFRY, MERSENNE, DEFAULT_RANDOM_ENGINE, Scalar};
5152
mod defines;
5253

5354
pub use dim4::Dim4;

0 commit comments

Comments
 (0)