diff --git a/polars_tdigest/__init__.py b/polars_tdigest/__init__.py index def9445..fc44611 100644 --- a/polars_tdigest/__init__.py +++ b/polars_tdigest/__init__.py @@ -23,13 +23,14 @@ def estimate_median(expr: IntoExpr) -> pl.Expr: ) -def tdigest(expr: IntoExpr) -> pl.Expr: +def tdigest(expr: IntoExpr, max_size: int = 100) -> pl.Expr: return register_plugin_function( plugin_path=Path(__file__).parent, function_name="tdigest", args=expr, is_elementwise=False, returns_scalar=True, + kwargs={"max_size": max_size}, ) @@ -44,11 +45,12 @@ def estimate_quantile(expr: IntoExpr, quantile: float) -> pl.Expr: ) -def tdigest_cast(expr: IntoExpr) -> pl.Expr: +def tdigest_cast(expr: IntoExpr, max_size: int = 100) -> pl.Expr: return register_plugin_function( plugin_path=Path(__file__).parent, function_name="tdigest_cast", args=expr, is_elementwise=False, returns_scalar=True, + kwargs={"max_size": max_size}, ) diff --git a/src/expressions.rs b/src/expressions.rs index 70b2be8..d8411b6 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -14,54 +14,6 @@ use std::io::Cursor; use std::num::NonZeroUsize; use tdigest::TDigest; -// mod tdigest; - -#[polars_expr(output_type=String)] -fn pig_latinnify(inputs: &[Series]) -> PolarsResult { - let ca: &StringChunked = inputs[0].str()?; - let out: StringChunked = ca.apply_to_buffer(|value: &str, output: &mut String| { - if let Some(first_char) = value.chars().next() { - write!(output, "{}{}ay", &value[1..], first_char).unwrap() - } - }); - Ok(out.into_series()) -} - -fn same_output_type(input_fields: &[Field]) -> PolarsResult { - let field = &input_fields[0]; - Ok(field.clone()) -} - -#[polars_expr(output_type_func=same_output_type)] -fn noop(inputs: &[Series]) -> PolarsResult { - let s = &inputs[0]; - Ok(s.clone()) -} - -// TODO estimate median should also work on t-digest and be a shortcut for estimate_quantile with quantile=0.5 -#[polars_expr(output_type=Int64)] -fn estimate_median(inputs: &[Series]) -> PolarsResult { - let values = &inputs[0].i64()?; - let t = TDigest::new_with_size(100); - let chunks: Vec = values - .downcast_iter() - .map(|chunk| { - let array = chunk.as_any().downcast_ref::().unwrap(); - let val_vec = array - .values() - .iter() - .filter_map(|v| Ok(Some(*v as f64)).transpose()) - .collect::, Vec>>(); - t.merge_unsorted(val_vec.unwrap().to_owned()) - }) - .collect(); - - let t_global = TDigest::merge_digests(chunks); - let ans = t_global.estimate_quantile(0.5); - - Ok(Series::new("", vec![ans])) -} - fn tdigest_output(_: &[Field]) -> PolarsResult { Ok(Field::new("tdigest", DataType::Struct(tdigest_fields()))) } @@ -83,29 +35,9 @@ fn tdigest_fields() -> Vec { ] } -// fn tidgest_compute(values: &ChunkedArray) -> Vec { -// let chunks: Vec = POOL.install(|| { -// values -// .downcast_iter() -// .par_bridge() -// .map(|chunk| { -// let t = TDigest::new_with_size(100); -// let array = chunk.as_any().downcast_ref::>().unwrap(); -// let val_vec: Vec = array -// .values() -// .iter() -// .filter_map(|v| Some(*v. as f64)) -// .collect(); -// t.merge_unsorted(val_vec.to_owned()) -// }) -// .collect::>() -// }); -// chunks -// } - // Todo support other numerical types #[polars_expr(output_type_func=tdigest_output)] -fn tdigest(inputs: &[Series]) -> PolarsResult { +fn tdigest(inputs: &[Series], kwargs: TDigestKwargs) -> PolarsResult { let series = &inputs[0]; // TODO: pooling is not feasible on small datasets let chunks = match series.dtype() { @@ -116,7 +48,7 @@ fn tdigest(inputs: &[Series]) -> PolarsResult { .downcast_iter() .par_bridge() .map(|chunk| { - let t = TDigest::new_with_size(100); + let t = TDigest::new_with_size(kwargs.max_size); let array = chunk.as_any().downcast_ref::().unwrap(); let val_vec: Vec = array.non_null_values_iter().collect(); t.merge_unsorted(val_vec.to_owned()) @@ -199,7 +131,7 @@ fn tdigest(inputs: &[Series]) -> PolarsResult { } #[polars_expr(output_type_func=tdigest_output)] -fn tdigest_cast(inputs: &[Series]) -> PolarsResult { +fn tdigest_cast(inputs: &[Series], kwargs: TDigestKwargs) -> PolarsResult { let supported_dtypes = &[ DataType::Float64, DataType::Float32, @@ -218,7 +150,7 @@ fn tdigest_cast(inputs: &[Series]) -> PolarsResult { .downcast_iter() .par_bridge() .map(|chunk| { - let t = TDigest::new_with_size(100); + let t = TDigest::new_with_size(kwargs.max_size); let array = chunk.as_any().downcast_ref::().unwrap(); t.merge_unsorted(array.values().to_vec()) }) @@ -250,7 +182,12 @@ struct MergeTDKwargs { quantile: f64, } -// TODO this should check the type of the series and also work on series of Type f64 +#[derive(Debug, Deserialize)] +struct TDigestKwargs { + max_size: usize, +} + +// TODO this should check the type of the series and also work on series other than of type f64 #[polars_expr(output_type=Float64)] fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult { let mut df = inputs[0].clone().into_frame(); diff --git a/tdigest_yellow_taxi.ipynb b/tdigest_yellow_taxi.ipynb index 3a1cc4f..8d2ad94 100644 --- a/tdigest_yellow_taxi.ipynb +++ b/tdigest_yellow_taxi.ipynb @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -46,33 +46,33 @@ "output_type": "stream", "text": [ "Dataset yellow_tripdata_2024-02.parquet has 3007526 rows\n", - "TDigest took: 118 ms\n", + "TDigest took: 119 ms\n", + "TDigest took: 108 ms\n", + "TDigest took: 108 ms\n", + "TDigest took: 109 ms\n", + "TDigest took: 108 ms\n", + "TDigest with cast took: 109 ms\n", + "TDigest with cast took: 108 ms\n", + "TDigest with cast took: 108 ms\n", + "TDigest with cast took: 109 ms\n", + "TDigest with cast took: 107 ms\n", + "Dataset yellow_tripdata_2024-03.parquet has 3582628 rows\n", + "TDigest took: 136 ms\n", + "TDigest took: 116 ms\n", "TDigest took: 113 ms\n", "TDigest took: 112 ms\n", - "TDigest took: 110 ms\n", - "TDigest took: 109 ms\n", - "TDigest with cast took: 115 ms\n", - "TDigest with cast took: 110 ms\n", + "TDigest took: 112 ms\n", "TDigest with cast took: 111 ms\n", "TDigest with cast took: 111 ms\n", "TDigest with cast took: 111 ms\n", - "Dataset yellow_tripdata_2024-03.parquet has 3582628 rows\n", - "TDigest took: 115 ms\n", - "TDigest took: 114 ms\n", - "TDigest took: 115 ms\n", - "TDigest took: 115 ms\n", - "TDigest took: 115 ms\n", - "TDigest with cast took: 118 ms\n", - "TDigest with cast took: 116 ms\n", - "TDigest with cast took: 116 ms\n", - "TDigest with cast took: 116 ms\n", - "TDigest with cast took: 117 ms\n", - "Estimate median took: 1 ms\n", + "TDigest with cast took: 112 ms\n", + "TDigest with cast took: 112 ms\n", + "Estimate median took: 0 ms\n", "Estimate median took: 0 ms\n", "Estimate median took: 0 ms\n", "Estimate median took: 0 ms\n", "Estimate median took: 0 ms\n", - "Estimated median = 1.0\n" + "Estimated median = 1.7201926614756424\n" ] } ], @@ -83,7 +83,8 @@ "datasets = []\n", "tdigests = []\n", "run_performance_test = True\n", - "numeric_col = \"passenger_count\"\n", + "numeric_col = \"trip_distance\"\n", + "max_size = 100\n", "\n", "for dataset in dataset_files:\n", " local_file = f\"{local_folder}{dataset}\"\n", @@ -94,8 +95,8 @@ "\n", " print(f\"Dataset {dataset} has {df.select(pl.len()).collect().item()} rows\")\n", "\n", - " query = df.select(tdigest(numeric_col))\n", - " query_cast = df.select(tdigest_cast(numeric_col))\n", + " query = df.select(tdigest(numeric_col, max_size))\n", + " query_cast = df.select(tdigest_cast(numeric_col, max_size))\n", " if run_performance_test:\n", " for _ in range(5):\n", " with Timer(text=\"TDigest took: {milliseconds:.0f} ms\"):\n", @@ -118,19 +119,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Median took: 48 ms\n", - "Median took: 42 ms\n", - "Median took: 43 ms\n", - "Median took: 45 ms\n", - "Median took: 41 ms\n", - "Median = 1.0\n" + "Median took: 36 ms\n", + "Median took: 28 ms\n", + "Median took: 28 ms\n", + "Median took: 30 ms\n", + "Median took: 29 ms\n", + "Median = 1.71\n" ] } ], @@ -149,23 +150,23 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Median on partition took: 23 ms\n", - "Median on partition took: 23 ms\n", - "Median on partition took: 22 ms\n", + "Median on partition took: 20 ms\n", + "Median on partition took: 20 ms\n", + "Median on partition took: 19 ms\n", + "Median on partition took: 19 ms\n", + "Median on partition took: 19 ms\n", "Median on partition took: 21 ms\n", - "Median on partition took: 23 ms\n", - "Median on partition took: 23 ms\n", - "Median on partition took: 24 ms\n", - "Median on partition took: 22 ms\n", - "Median on partition took: 23 ms\n", - "Median on partition took: 24 ms\n" + "Median on partition took: 21 ms\n", + "Median on partition took: 20 ms\n", + "Median on partition took: 20 ms\n", + "Median on partition took: 20 ms\n" ] } ], @@ -175,18 +176,11 @@ " with Timer(text=\"Median on partition took: {milliseconds:.0f} ms\"):\n", " partition.select(col(numeric_col).median()).collect()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": ".env", + "display_name": ".venv", "language": "python", "name": "python3" },