From 41a8668df489c75bee30e6e0808c69d7f114177f Mon Sep 17 00:00:00 2001 From: Thomas Lutterbeck Date: Mon, 18 Nov 2024 12:09:36 +0100 Subject: [PATCH 1/2] added function to merge tdigest with small refactoring --- polars_tdigest/__init__.py | 10 ++++++++++ src/expressions.rs | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/polars_tdigest/__init__.py b/polars_tdigest/__init__.py index def9445..aa58148 100644 --- a/polars_tdigest/__init__.py +++ b/polars_tdigest/__init__.py @@ -52,3 +52,13 @@ def tdigest_cast(expr: IntoExpr) -> pl.Expr: is_elementwise=False, returns_scalar=True, ) + + +def merge_tdigests(expr: IntoExpr) -> pl.Expr: + return register_plugin_function( + plugin_path=Path(__file__).parent, + function_name="merge_tdigests", + args=expr, + is_elementwise=False, + returns_scalar=True, + ) diff --git a/src/expressions.rs b/src/expressions.rs index 166e578..19a7bcf 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -250,9 +250,7 @@ struct MergeTDKwargs { quantile: f64, } -// TODO this should check the type of the series and also work on series of Type f64 -#[polars_expr(output_type=Float64)] -fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult { +fn extract_tdigest_vec(inputs: &[Series]) -> Vec { let mut df = inputs[0].clone().into_frame(); df.set_column_names(vec!["tdigest"].as_slice()).unwrap(); let mut buf = BufWriter::new(Vec::new()); @@ -262,8 +260,35 @@ fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult = - serde_json::from_str(&json_str).expect("Failed to parse the tdigest JSON string"); + + serde_json::from_str(&json_str).expect("Failed to parse the tdigest JSON string") +} + +#[polars_expr(output_type_func=tdigest_output)] +fn merge_tdigests(inputs: &[Series]) -> PolarsResult { + let series = &inputs[0]; + let tdigest_json: Vec = extract_tdigest_vec(inputs); + + let tdigests: Vec = tdigest_json.into_iter().map(|td| td.tdigest).collect(); + let tdigest = TDigest::merge_digests(tdigests); + + let td_json = serde_json::to_string(&tdigest).unwrap(); + + let file = Cursor::new(&td_json); + let df = JsonReader::new(file) + .with_json_format(JsonFormat::JsonLines) + .infer_schema_len(Some(3)) + .with_batch_size(NonZeroUsize::new(3).unwrap()) + .finish() + .unwrap(); + + Ok(df.into_struct(series.name()).into_series()) +} + +// TODO this should check the type of the series and also work on series of Type f64 +#[polars_expr(output_type=Float64)] +fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult { + let tdigest_json: Vec = extract_tdigest_vec(inputs); let tdigests: Vec = tdigest_json.into_iter().map(|td| td.tdigest).collect(); let tdigest = TDigest::merge_digests(tdigests); From e6fde3104a42d041ad92944aa48dd808023c3b77 Mon Sep 17 00:00:00 2001 From: Thomas Lutterbeck Date: Mon, 18 Nov 2024 17:35:45 +0100 Subject: [PATCH 2/2] replaces schema inference with actual schema --- src/expressions.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expressions.rs b/src/expressions.rs index 19a7bcf..26c1529 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -232,7 +232,7 @@ fn tdigest_cast(inputs: &[Series]) -> PolarsResult { let file = Cursor::new(&td_json); let df = JsonReader::new(file) .with_json_format(JsonFormat::JsonLines) - .infer_schema_len(Some(3)) + .with_schema(Arc::new(Schema::from_iter(tdigest_fields()))) .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() .unwrap(); @@ -277,7 +277,7 @@ fn merge_tdigests(inputs: &[Series]) -> PolarsResult { let file = Cursor::new(&td_json); let df = JsonReader::new(file) .with_json_format(JsonFormat::JsonLines) - .infer_schema_len(Some(3)) + .with_schema(Arc::new(Schema::from_iter(tdigest_fields()))) .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() .unwrap();