@@ -3,7 +3,7 @@ extern crate num;
3
3
4
4
use array:: Array ;
5
5
use dim4:: Dim4 ;
6
- use defines:: AfError ;
6
+ use defines:: { AfError , DType , Scalar } ;
7
7
use error:: HANDLE_ERROR ;
8
8
use self :: libc:: { uint8_t, c_int, c_uint, c_double} ;
9
9
use self :: num:: Complex ;
@@ -623,24 +623,137 @@ pub fn replace_scalar(a: &mut Array, cond: &Array, b: f64) {
623
623
}
624
624
}
625
625
626
- /// Create an array filled with given constant retaining type/shape of another Array.
626
+ /// Create a range of values of given type([DType](./enum.DType.html))
627
+ ///
628
+ /// Creates an array with [0, n] values along the `seq_dim` which is tiled across other dimensions.
629
+ ///
630
+ /// # Parameters
631
+ ///
632
+ /// - `dims` is the size of Array
633
+ /// - `seq_dim` is the dimension along which range values are populated, all values along other
634
+ /// dimensions are just repeated
635
+ /// - `dtype` indicates whats the type of the Array to be created
636
+ ///
637
+ /// # Return Values
638
+ /// Array
639
+ #[ allow( unused_mut) ]
640
+ pub fn range_t ( dims : Dim4 , seq_dim : i32 , dtype : DType ) -> Array {
641
+ unsafe {
642
+ let mut temp: i64 = 0 ;
643
+ let err_val = af_range ( & mut temp as MutAfArray ,
644
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT ,
645
+ seq_dim as c_int , dtype as uint8_t ) ;
646
+ HANDLE_ERROR ( AfError :: from ( err_val) ) ;
647
+ Array :: from ( temp)
648
+ }
649
+ }
650
+
651
+ /// Create a range of values of given type([DType](./enum.DType.html))
652
+ ///
653
+ /// Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
654
+ ///
655
+ /// # Parameters
656
+ ///
657
+ /// - `dims` is the dimensions of the sequence to be generated
658
+ /// - `tdims` is the number of repitions of the unit dimensions
659
+ /// - `dtype` indicates whats the type of the Array to be created
660
+ ///
661
+ /// # Return Values
662
+ ///
663
+ /// Array
664
+ #[ allow( unused_mut) ]
665
+ pub fn iota_t ( dims : Dim4 , tdims : Dim4 , dtype : DType ) -> Array {
666
+ unsafe {
667
+ let mut temp: i64 = 0 ;
668
+ let err_val =af_iota ( & mut temp as MutAfArray ,
669
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT ,
670
+ tdims. ndims ( ) as c_uint , tdims. get ( ) . as_ptr ( ) as * const DimT ,
671
+ dtype as uint8_t ) ;
672
+ HANDLE_ERROR ( AfError :: from ( err_val) ) ;
673
+ Array :: from ( temp)
674
+ }
675
+ }
676
+
677
+ /// Create an identity array with 1's in diagonal of given type([DType](./enum.DType.html))
627
678
///
628
679
/// # Parameters
629
680
///
630
- /// - `value ` is the constant with which output Array is to be filled
631
- /// - `input` is the Array whose shape the output Array has to maintain
681
+ /// - `dims ` is the output Array dimensions
682
+ /// - `dtype` indicates whats the type of the Array to be created
632
683
///
633
684
/// # Return Values
634
685
///
635
- /// Array with given constant value and input Array's shape and similar internal data type.
636
- pub fn constant_like ( value : f64 , input : & Array ) -> Array {
637
- let dims = input. dims ( ) ;
686
+ /// Identity matrix
687
+ #[ allow( unused_mut) ]
688
+ pub fn identity_t ( dims : Dim4 , dtype : DType ) -> Array {
689
+ unsafe {
690
+ let mut temp: i64 = 0 ;
691
+ let err_val = af_identity ( & mut temp as MutAfArray ,
692
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT ,
693
+ dtype as uint8_t ) ;
694
+ HANDLE_ERROR ( AfError :: from ( err_val) ) ;
695
+ Array :: from ( temp)
696
+ }
697
+ }
698
+
699
+ /// Create a constant array of given type([DType](./enum.DType.html))
700
+ ///
701
+ /// You can use this function to create arrays of type dictated by the enum
702
+ /// [DType](./enum.DType.html) using the scalar `value` that has the shape similar
703
+ /// to `dims`.
704
+ ///
705
+ /// # Parameters
706
+ ///
707
+ /// - `value` is the [Scalar](./enum.Scalar.html) to be filled into the array
708
+ /// - `dims` is the output Array dimensions
709
+ /// - `dtype` indicates the type of Array to be created and is the type of the scalar to be passed
710
+ /// via the paramter `value`.
711
+ ///
712
+ /// # Return Values
713
+ ///
714
+ /// Array of `dims` shape and filed with given constant `value`.
715
+ #[ allow( unused_mut) ]
716
+ pub fn constant_t ( value : Scalar , dims : Dim4 , dtype : DType ) -> Array {
717
+ use Scalar :: * ;
718
+
719
+ // Below macro is only visible to this function
720
+ // and it is used to abbreviate the repetitive const calls
721
+ macro_rules! expand_const_call {
722
+ ( $ffi_name: ident, $temp: expr, $v: expr, $dims: expr, $dt: expr) => ( {
723
+ $ffi_name( & mut $temp as MutAfArray , $v as c_double,
724
+ $dims. ndims( ) as c_uint, $dims. get( ) . as_ptr( ) as * const DimT , $dt)
725
+ } )
726
+ }
727
+
638
728
unsafe {
729
+ let dt = dtype as c_int ;
639
730
let mut temp: i64 = 0 ;
640
- let err_val = af_constant ( & mut temp as MutAfArray , value as c_double ,
641
- dims. ndims ( ) as c_uint ,
642
- dims. get ( ) . as_ptr ( ) as * const DimT ,
643
- input. get_type ( ) as c_int ) ;
731
+ let err_val = match value {
732
+ C32 ( v) => {
733
+ af_constant_complex ( & mut temp as MutAfArray , v. re as c_double , v. im as c_double ,
734
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT , dt)
735
+ } ,
736
+ C64 ( v) => {
737
+ af_constant_complex ( & mut temp as MutAfArray , v. re as c_double , v. im as c_double ,
738
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT , dt)
739
+ } ,
740
+ S64 ( v) => {
741
+ af_constant_long ( & mut temp as MutAfArray , v as Intl ,
742
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT )
743
+ } ,
744
+ U64 ( v) => {
745
+ af_constant_ulong ( & mut temp as MutAfArray , v as Uintl ,
746
+ dims. ndims ( ) as c_uint , dims. get ( ) . as_ptr ( ) as * const DimT )
747
+ } ,
748
+ F32 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
749
+ F64 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
750
+ B8 ( v) => expand_const_call ! ( af_constant, temp, v as i32 , dims, dt) ,
751
+ S32 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
752
+ U32 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
753
+ U8 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
754
+ S16 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
755
+ U16 ( v) => expand_const_call ! ( af_constant, temp, v, dims, dt) ,
756
+ } ;
644
757
HANDLE_ERROR ( AfError :: from ( err_val) ) ;
645
758
Array :: from ( temp)
646
759
}
0 commit comments