@@ -68,7 +68,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
68
68
}
69
69
70
70
match ( from_type, to_type) {
71
- (
71
+ // TODO now just support signed numeric to decimal, support decimal to numeric later
72
+ ( Int8 | Int16 | Int32 | Int64 | Float32 | Float64 , Decimal ( _, _) )
73
+ | (
72
74
Null ,
73
75
Boolean
74
76
| Int8
@@ -870,6 +872,45 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
870
872
cast_with_options ( array, to_type, & DEFAULT_CAST_OPTIONS )
871
873
}
872
874
875
+ // cast the integer array to defined decimal data type array
876
+ macro_rules! cast_integer_to_decimal {
877
+ ( $ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => { {
878
+ let mut decimal_builder = DecimalBuilder :: new( $ARRAY. len( ) , * $PRECISION, * $SCALE) ;
879
+ let array = $ARRAY. as_any( ) . downcast_ref:: <$ARRAY_TYPE>( ) . unwrap( ) ;
880
+ let mul: i128 = 10_i128 . pow( * $SCALE as u32 ) ;
881
+ for i in 0 ..array. len( ) {
882
+ if array. is_null( i) {
883
+ decimal_builder. append_null( ) ?;
884
+ } else {
885
+ // convert i128 first
886
+ let v = array. value( i) as i128 ;
887
+ // if the input value is overflow, it will throw an error.
888
+ decimal_builder. append_value( mul * v) ?;
889
+ }
890
+ }
891
+ Ok ( Arc :: new( decimal_builder. finish( ) ) )
892
+ } } ;
893
+ }
894
+
895
+ // cast the floating-point array to defined decimal data type array
896
+ macro_rules! cast_floating_point_to_decimal {
897
+ ( $ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => { {
898
+ let mut decimal_builder = DecimalBuilder :: new( $ARRAY. len( ) , * $PRECISION, * $SCALE) ;
899
+ let array = $ARRAY. as_any( ) . downcast_ref:: <$ARRAY_TYPE>( ) . unwrap( ) ;
900
+ let mul = 10_f64 . powi( * $SCALE as i32 ) ;
901
+ for i in 0 ..array. len( ) {
902
+ if array. is_null( i) {
903
+ decimal_builder. append_null( ) ?;
904
+ } else {
905
+ let v = ( ( array. value( i) as f64 ) * mul) as i128 ;
906
+ // if the input value is overflow, it will throw an error.
907
+ decimal_builder. append_value( v) ?;
908
+ }
909
+ }
910
+ Ok ( Arc :: new( decimal_builder. finish( ) ) )
911
+ } } ;
912
+ }
913
+
873
914
/// Cast `array` to the provided data type and return a new Array with
874
915
/// type `to_type`, if possible. It accepts `CastOptions` to allow consumers
875
916
/// to configure cast behavior.
@@ -904,6 +945,34 @@ pub fn cast_with_options(
904
945
return Ok ( array. clone ( ) ) ;
905
946
}
906
947
match ( from_type, to_type) {
948
+ ( _, Decimal ( precision, scale) ) => {
949
+ // cast data to decimal
950
+ match from_type {
951
+ // TODO now just support signed numeric to decimal, support decimal to numeric later
952
+ Int8 => {
953
+ cast_integer_to_decimal ! ( array, Int8Array , precision, scale)
954
+ }
955
+ Int16 => {
956
+ cast_integer_to_decimal ! ( array, Int16Array , precision, scale)
957
+ }
958
+ Int32 => {
959
+ cast_integer_to_decimal ! ( array, Int32Array , precision, scale)
960
+ }
961
+ Int64 => {
962
+ cast_integer_to_decimal ! ( array, Int64Array , precision, scale)
963
+ }
964
+ Float32 => {
965
+ cast_floating_point_to_decimal ! ( array, Float32Array , precision, scale)
966
+ }
967
+ Float64 => {
968
+ cast_floating_point_to_decimal ! ( array, Float64Array , precision, scale)
969
+ }
970
+ _ => Err ( ArrowError :: CastError ( format ! (
971
+ "Casting from {:?} to {:?} not supported" ,
972
+ from_type, to_type
973
+ ) ) ) ,
974
+ }
975
+ }
907
976
(
908
977
Null ,
909
978
Boolean
@@ -2074,7 +2143,7 @@ fn cast_string_to_date64<Offset: StringOffsetSizeTrait>(
2074
2143
if string_array. is_null ( i) {
2075
2144
Ok ( None )
2076
2145
} else {
2077
- let string = string_array
2146
+ let string = string_array
2078
2147
. value ( i) ;
2079
2148
2080
2149
let result = string
@@ -2291,7 +2360,7 @@ fn dictionary_cast<K: ArrowDictionaryKeyType>(
2291
2360
return Err ( ArrowError :: CastError ( format ! (
2292
2361
"Unsupported type {:?} for dictionary index" ,
2293
2362
to_index_type
2294
- ) ) )
2363
+ ) ) ) ;
2295
2364
}
2296
2365
} ;
2297
2366
@@ -2655,6 +2724,115 @@ where
2655
2724
mod tests {
2656
2725
use super :: * ;
2657
2726
use crate :: { buffer:: Buffer , util:: display:: array_value_to_string} ;
2727
+ use num:: traits:: Pow ;
2728
+
2729
+ #[ test]
2730
+ fn test_cast_numeric_to_decimal ( ) {
2731
+ // test cast type
2732
+ let data_types = vec ! [
2733
+ DataType :: Int8 ,
2734
+ DataType :: Int16 ,
2735
+ DataType :: Int32 ,
2736
+ DataType :: Int64 ,
2737
+ DataType :: Float32 ,
2738
+ DataType :: Float64 ,
2739
+ ] ;
2740
+ let decimal_type = DataType :: Decimal ( 38 , 6 ) ;
2741
+ for data_type in data_types {
2742
+ assert ! ( can_cast_types( & data_type, & decimal_type) )
2743
+ }
2744
+ assert ! ( !can_cast_types( & DataType :: UInt64 , & decimal_type) ) ;
2745
+
2746
+ // test cast data
2747
+ let input_datas = vec ! [
2748
+ Arc :: new( Int8Array :: from( vec![
2749
+ Some ( 1 ) ,
2750
+ Some ( 2 ) ,
2751
+ Some ( 3 ) ,
2752
+ None ,
2753
+ Some ( 5 ) ,
2754
+ ] ) ) as ArrayRef , // i8
2755
+ Arc :: new( Int16Array :: from( vec![
2756
+ Some ( 1 ) ,
2757
+ Some ( 2 ) ,
2758
+ Some ( 3 ) ,
2759
+ None ,
2760
+ Some ( 5 ) ,
2761
+ ] ) ) as ArrayRef , // i16
2762
+ Arc :: new( Int32Array :: from( vec![
2763
+ Some ( 1 ) ,
2764
+ Some ( 2 ) ,
2765
+ Some ( 3 ) ,
2766
+ None ,
2767
+ Some ( 5 ) ,
2768
+ ] ) ) as ArrayRef , // i32
2769
+ Arc :: new( Int64Array :: from( vec![
2770
+ Some ( 1 ) ,
2771
+ Some ( 2 ) ,
2772
+ Some ( 3 ) ,
2773
+ None ,
2774
+ Some ( 5 ) ,
2775
+ ] ) ) as ArrayRef , // i64
2776
+ ] ;
2777
+
2778
+ // i8, i16, i32, i64
2779
+ for array in input_datas {
2780
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2781
+ let decimal_array = casted_array
2782
+ . as_any ( )
2783
+ . downcast_ref :: < DecimalArray > ( )
2784
+ . unwrap ( ) ;
2785
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2786
+ for i in 0 ..array. len ( ) {
2787
+ if i == 3 {
2788
+ assert ! ( decimal_array. is_null( i as usize ) ) ;
2789
+ } else {
2790
+ assert_eq ! (
2791
+ 10_i128 . pow( 6 ) * ( i as i128 + 1 ) ,
2792
+ decimal_array. value( i as usize )
2793
+ ) ;
2794
+ }
2795
+ }
2796
+ }
2797
+
2798
+ // test i8 to decimal type with overflow the result type
2799
+ // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3.
2800
+ let array = Int8Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 100 ] ) ;
2801
+ let array = Arc :: new ( array) as ArrayRef ;
2802
+ let casted_array = cast ( & array, & DataType :: Decimal ( 3 , 1 ) ) ;
2803
+ assert ! ( casted_array. is_err( ) ) ;
2804
+ assert_eq ! ( "Invalid argument error: The value of 1000 i128 is not compatible with Decimal(3,1)" , casted_array. unwrap_err( ) . to_string( ) ) ;
2805
+
2806
+ // test f32 to decimal type
2807
+ let f_data: Vec < f32 > = vec ! [ 1.1 , 2.2 , 4.4 , 1.123_456_8 ] ;
2808
+ let array = Float32Array :: from ( f_data. clone ( ) ) ;
2809
+ let array = Arc :: new ( array) as ArrayRef ;
2810
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2811
+ let decimal_array = casted_array
2812
+ . as_any ( )
2813
+ . downcast_ref :: < DecimalArray > ( )
2814
+ . unwrap ( ) ;
2815
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2816
+ for ( i, item) in f_data. iter ( ) . enumerate ( ) . take ( array. len ( ) ) {
2817
+ let left = ( * item as f64 ) * 10_f64 . pow ( 6 ) ;
2818
+ assert_eq ! ( left as i128 , decimal_array. value( i as usize ) ) ;
2819
+ }
2820
+
2821
+ // test f64 to decimal type
2822
+ let f_data: Vec < f64 > = vec ! [ 1.1 , 2.2 , 4.4 , 1.123_456_789_123_4 ] ;
2823
+ let array = Float64Array :: from ( f_data. clone ( ) ) ;
2824
+ let array = Arc :: new ( array) as ArrayRef ;
2825
+ let casted_array = cast ( & array, & decimal_type) . unwrap ( ) ;
2826
+ let decimal_array = casted_array
2827
+ . as_any ( )
2828
+ . downcast_ref :: < DecimalArray > ( )
2829
+ . unwrap ( ) ;
2830
+ assert_eq ! ( & decimal_type, decimal_array. data_type( ) ) ;
2831
+ for ( i, item) in f_data. iter ( ) . enumerate ( ) . take ( array. len ( ) ) {
2832
+ let left = ( * item as f64 ) * 10_f64 . pow ( 6 ) ;
2833
+ assert_eq ! ( left as i128 , decimal_array. value( i as usize ) ) ;
2834
+ }
2835
+ }
2658
2836
2659
2837
#[ test]
2660
2838
fn test_cast_i32_to_f64 ( ) {
0 commit comments