diff --git a/src/expressions.rs b/src/expressions.rs index 70b2be8..53dee7c 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -103,79 +103,43 @@ fn tdigest_fields() -> Vec { // chunks // } -// Todo support other numerical types -#[polars_expr(output_type_func=tdigest_output)] -fn tdigest(inputs: &[Series]) -> PolarsResult { - let series = &inputs[0]; - // TODO: pooling is not feasible on small datasets - let chunks = match series.dtype() { - DataType::Float64 => { - let values = series.f64()?; - let chunks: Vec = POOL.install(|| { - values - .downcast_iter() - .par_bridge() - .map(|chunk| { - let t = TDigest::new_with_size(100); - let array = chunk.as_any().downcast_ref::().unwrap(); - let val_vec: Vec = array.non_null_values_iter().collect(); - t.merge_unsorted(val_vec.to_owned()) - }) - .collect::>() - }); - chunks - } - DataType::Float32 => { - let values = series.f32()?; - let chunks: Vec = POOL.install(|| { - values - .downcast_iter() - .par_bridge() - .map(|chunk| { - let t = TDigest::new_with_size(100); - let array = chunk.as_any().downcast_ref::().unwrap(); - let val_vec: Vec = - array.non_null_values_iter().map(|v| (v as f64)).collect(); - t.merge_unsorted(val_vec.to_owned()) - }) - .collect::>() - }); - chunks - } - DataType::Int64 => { - let values = series.i64()?; +macro_rules! gen { + ($func:ident, $a_f64:ident, $a_Float64Array: ident) => { + fn $func(series: &Series) -> PolarsResult> { + let values = series.$a_f64()?; let chunks: Vec = POOL.install(|| { values .downcast_iter() .par_bridge() .map(|chunk| { let t = TDigest::new_with_size(100); - let array = chunk.as_any().downcast_ref::().unwrap(); - let val_vec: Vec = - array.non_null_values_iter().map(|v| (v as f64)).collect(); + let array = chunk.as_any().downcast_ref::<$a_Float64Array>().unwrap(); + let val_vec: Vec = array.non_null_values_iter().map(|v| (v as f64)).collect(); t.merge_unsorted(val_vec.to_owned()) }) .collect::>() }); - chunks - } - DataType::Int32 => { - let values = series.i32()?; - let chunks: Vec = POOL.install(|| { - values - .downcast_iter() - .par_bridge() - .map(|chunk| { - let t = TDigest::new_with_size(100); - let array = chunk.as_any().downcast_ref::().unwrap(); - let val_vec: Vec = - array.non_null_values_iter().map(|v| (v as f64)).collect(); - t.merge_unsorted(val_vec.to_owned()) - }) - .collect::>() - }); - chunks + Ok(chunks) } + }; +} + +gen!(gen_f64, f64, Float64Array); +gen!(gen_f32, f32, Float32Array); +gen!(gen_i64, i64, Int64Array); +gen!(gen_i32, i32, Int32Array); + + +// Todo support other numerical types +#[polars_expr(output_type_func=tdigest_output)] +fn tdigest(inputs: &[Series]) -> PolarsResult { + let series = &inputs[0]; + // TODO: pooling is not feasible on small datasets + let chunks = match series.dtype() { + DataType::Float64 => gen_f64(series)?, + DataType::Float32 => gen_f32(series)?, + DataType::Int64 => gen_i64(series)?, + DataType::Int32 => gen_i32(series)?, _ => polars_bail!(InvalidOperation: "only supported for numerical types"), };