added max_size parameter and cleaned up unused code
thobai committed Nov 21, 2024
1 parent fabb319 commit 5438c58
Showing 3 changed files with 55 additions and 122 deletions.
6 changes: 4 additions & 2 deletions polars_tdigest/__init__.py
@@ -23,13 +23,14 @@ def estimate_median(expr: IntoExpr) -> pl.Expr:
     )
 
 
-def tdigest(expr: IntoExpr) -> pl.Expr:
+def tdigest(expr: IntoExpr, max_size: int = 100) -> pl.Expr:
     return register_plugin_function(
         plugin_path=Path(__file__).parent,
         function_name="tdigest",
         args=expr,
         is_elementwise=False,
         returns_scalar=True,
+        kwargs={"max_size": max_size},
     )
 
 
@@ -44,11 +45,12 @@ def estimate_quantile(expr: IntoExpr, quantile: float) -> pl.Expr:
     )
 
 
-def tdigest_cast(expr: IntoExpr) -> pl.Expr:
+def tdigest_cast(expr: IntoExpr, max_size: int = 100) -> pl.Expr:
     return register_plugin_function(
         plugin_path=Path(__file__).parent,
         function_name="tdigest_cast",
         args=expr,
         is_elementwise=False,
         returns_scalar=True,
+        kwargs={"max_size": max_size},
     )
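Usage sketch for the new parameter (a minimal example, assuming the package is importable as polars_tdigest; the file and column names here are illustrative, not taken from this repo):

import polars as pl
from polars_tdigest import tdigest

# max_size caps the number of centroids the t-digest keeps: larger values
# give tighter quantile estimates at the cost of memory and merge time.
lf = pl.scan_parquet("trips.parquet")  # hypothetical input file
digest = lf.select(tdigest("trip_distance", max_size=200)).collect()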
83 changes: 10 additions & 73 deletions src/expressions.rs
@@ -14,54 +14,6 @@ use std::io::Cursor;
 use std::num::NonZeroUsize;
 use tdigest::TDigest;
 
-// mod tdigest;
-
-#[polars_expr(output_type=String)]
-fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
-    let ca: &StringChunked = inputs[0].str()?;
-    let out: StringChunked = ca.apply_to_buffer(|value: &str, output: &mut String| {
-        if let Some(first_char) = value.chars().next() {
-            write!(output, "{}{}ay", &value[1..], first_char).unwrap()
-        }
-    });
-    Ok(out.into_series())
-}
-
-fn same_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
-    let field = &input_fields[0];
-    Ok(field.clone())
-}
-
-#[polars_expr(output_type_func=same_output_type)]
-fn noop(inputs: &[Series]) -> PolarsResult<Series> {
-    let s = &inputs[0];
-    Ok(s.clone())
-}
-
-// TODO estimate median should also work on t-digest and be a shortcut for estimate_quantile with quantile=0.5
-#[polars_expr(output_type=Int64)]
-fn estimate_median(inputs: &[Series]) -> PolarsResult<Series> {
-    let values = &inputs[0].i64()?;
-    let t = TDigest::new_with_size(100);
-    let chunks: Vec<TDigest> = values
-        .downcast_iter()
-        .map(|chunk| {
-            let array = chunk.as_any().downcast_ref::<Int64Array>().unwrap();
-            let val_vec = array
-                .values()
-                .iter()
-                .filter_map(|v| Ok(Some(*v as f64)).transpose())
-                .collect::<Result<Vec<f64>, Vec<f64>>>();
-            t.merge_unsorted(val_vec.unwrap().to_owned())
-        })
-        .collect();
-
-    let t_global = TDigest::merge_digests(chunks);
-    let ans = t_global.estimate_quantile(0.5);
-
-    Ok(Series::new("", vec![ans]))
-}
-
 fn tdigest_output(_: &[Field]) -> PolarsResult<Field> {
     Ok(Field::new("tdigest", DataType::Struct(tdigest_fields())))
 }
@@ -83,29 +35,9 @@ fn tdigest_fields() -> Vec<Field> {
     ]
 }
 
-// fn tidgest_compute<T: NumericNative, PDT: PolarsDataType>(values: &ChunkedArray<PDT>) -> Vec<TDigest> {
-//     let chunks: Vec<TDigest> = POOL.install(|| {
-//         values
-//             .downcast_iter()
-//             .par_bridge()
-//             .map(|chunk| {
-//                 let t = TDigest::new_with_size(100);
-//                 let array = chunk.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
-//                 let val_vec: Vec<f64> = array
-//                     .values()
-//                     .iter()
-//                     .filter_map(|v| Some(*v. as f64))
-//                     .collect();
-//                 t.merge_unsorted(val_vec.to_owned())
-//             })
-//             .collect::<Vec<TDigest>>()
-//     });
-//     chunks
-// }
-
 // Todo support other numerical types
 #[polars_expr(output_type_func=tdigest_output)]
-fn tdigest(inputs: &[Series]) -> PolarsResult<Series> {
+fn tdigest(inputs: &[Series], kwargs: TDigestKwargs) -> PolarsResult<Series> {
     let series = &inputs[0];
     // TODO: pooling is not feasible on small datasets
     let chunks = match series.dtype() {
@@ -116,7 +48,7 @@ fn tdigest(inputs: &[Series]) -> PolarsResult<Series> {
                 .downcast_iter()
                 .par_bridge()
                 .map(|chunk| {
-                    let t = TDigest::new_with_size(100);
+                    let t = TDigest::new_with_size(kwargs.max_size);
                     let array = chunk.as_any().downcast_ref::<Float64Array>().unwrap();
                     let val_vec: Vec<f64> = array.non_null_values_iter().collect();
                     t.merge_unsorted(val_vec.to_owned())
@@ -199,7 +131,7 @@ fn tdigest(inputs: &[Series]) -> PolarsResult<Series> {
 }
 
 #[polars_expr(output_type_func=tdigest_output)]
-fn tdigest_cast(inputs: &[Series]) -> PolarsResult<Series> {
+fn tdigest_cast(inputs: &[Series], kwargs: TDigestKwargs) -> PolarsResult<Series> {
     let supported_dtypes = &[
         DataType::Float64,
         DataType::Float32,
@@ -218,7 +150,7 @@ fn tdigest_cast(inputs: &[Series]) -> PolarsResult<Series> {
             .downcast_iter()
             .par_bridge()
             .map(|chunk| {
-                let t = TDigest::new_with_size(100);
+                let t = TDigest::new_with_size(kwargs.max_size);
                 let array = chunk.as_any().downcast_ref::<Float64Array>().unwrap();
                 t.merge_unsorted(array.values().to_vec())
             })
@@ -250,7 +182,12 @@ struct MergeTDKwargs {
     quantile: f64,
 }
 
-// TODO this should check the type of the series and also work on series of Type f64
+#[derive(Debug, Deserialize)]
+struct TDigestKwargs {
+    max_size: usize,
+}
+
+// TODO this should check the type of the series and also work on series other than of type f64
 #[polars_expr(output_type=Float64)]
 fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult<Series> {
     let mut df = inputs[0].clone().into_frame();
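The kwargs plumbing here mirrors the Python wrapper above: register_plugin_function serializes kwargs={"max_size": max_size}, and pyo3-polars deserializes it into the serde-derived TDigestKwargs before the expression runs. A sketch of the two-phase workflow these functions enable (function names come from this diff; addressing the merged digests through a column literally named "tdigest", per tdigest_output, is an assumption):

import polars as pl
from polars_tdigest import estimate_quantile, tdigest

files = ["yellow_tripdata_2024-02.parquet", "yellow_tripdata_2024-03.parquet"]

# Phase 1: collapse each file into a single compact t-digest struct.
digests = pl.concat(
    [pl.scan_parquet(f).select(tdigest("trip_distance", max_size=100)) for f in files]
)

# Phase 2: merge the per-file digests and query a quantile
# without rescanning the raw data.
p90 = digests.select(estimate_quantile("tdigest", 0.9)).collect().item()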
88 changes: 41 additions & 47 deletions tdigest_yellow_taxi.ipynb
@@ -38,41 +38,41 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset yellow_tripdata_2024-02.parquet has 3007526 rows\n",
"TDigest took: 118 ms\n",
"TDigest took: 119 ms\n",
"TDigest took: 108 ms\n",
"TDigest took: 108 ms\n",
"TDigest took: 109 ms\n",
"TDigest took: 108 ms\n",
"TDigest with cast took: 109 ms\n",
"TDigest with cast took: 108 ms\n",
"TDigest with cast took: 108 ms\n",
"TDigest with cast took: 109 ms\n",
"TDigest with cast took: 107 ms\n",
"Dataset yellow_tripdata_2024-03.parquet has 3582628 rows\n",
"TDigest took: 136 ms\n",
"TDigest took: 116 ms\n",
"TDigest took: 113 ms\n",
"TDigest took: 112 ms\n",
"TDigest took: 110 ms\n",
"TDigest took: 109 ms\n",
"TDigest with cast took: 115 ms\n",
"TDigest with cast took: 110 ms\n",
"TDigest took: 112 ms\n",
"TDigest with cast took: 111 ms\n",
"TDigest with cast took: 111 ms\n",
"TDigest with cast took: 111 ms\n",
"Dataset yellow_tripdata_2024-03.parquet has 3582628 rows\n",
"TDigest took: 115 ms\n",
"TDigest took: 114 ms\n",
"TDigest took: 115 ms\n",
"TDigest took: 115 ms\n",
"TDigest took: 115 ms\n",
"TDigest with cast took: 118 ms\n",
"TDigest with cast took: 116 ms\n",
"TDigest with cast took: 116 ms\n",
"TDigest with cast took: 116 ms\n",
"TDigest with cast took: 117 ms\n",
"Estimate median took: 1 ms\n",
"TDigest with cast took: 112 ms\n",
"TDigest with cast took: 112 ms\n",
"Estimate median took: 0 ms\n",
"Estimate median took: 0 ms\n",
"Estimate median took: 0 ms\n",
"Estimate median took: 0 ms\n",
"Estimate median took: 0 ms\n",
"Estimated median = 1.0\n"
"Estimated median = 1.7201926614756424\n"
]
}
],
@@ -83,7 +83,8 @@
 "datasets = []\n",
 "tdigests = []\n",
 "run_performance_test = True\n",
-"numeric_col = \"passenger_count\"\n",
+"numeric_col = \"trip_distance\"\n",
+"max_size = 100\n",
 "\n",
 "for dataset in dataset_files:\n",
 " local_file = f\"{local_folder}{dataset}\"\n",
@@ -94,8 +95,8 @@
 "\n",
 " print(f\"Dataset {dataset} has {df.select(pl.len()).collect().item()} rows\")\n",
 "\n",
-" query = df.select(tdigest(numeric_col))\n",
-" query_cast = df.select(tdigest_cast(numeric_col))\n",
+" query = df.select(tdigest(numeric_col, max_size))\n",
+" query_cast = df.select(tdigest_cast(numeric_col, max_size))\n",
 " if run_performance_test:\n",
 " for _ in range(5):\n",
 " with Timer(text=\"TDigest took: {milliseconds:.0f} ms\"):\n",
@@ -118,19 +119,19 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Median took: 48 ms\n",
-"Median took: 42 ms\n",
-"Median took: 43 ms\n",
-"Median took: 45 ms\n",
-"Median took: 41 ms\n",
-"Median = 1.0\n"
+"Median took: 36 ms\n",
+"Median took: 28 ms\n",
+"Median took: 28 ms\n",
+"Median took: 30 ms\n",
+"Median took: 29 ms\n",
+"Median = 1.71\n"
 ]
 }
 ],
@@ -149,23 +150,23 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Median on partition took: 23 ms\n",
"Median on partition took: 23 ms\n",
"Median on partition took: 22 ms\n",
"Median on partition took: 20 ms\n",
"Median on partition took: 20 ms\n",
"Median on partition took: 19 ms\n",
"Median on partition took: 19 ms\n",
"Median on partition took: 19 ms\n",
"Median on partition took: 21 ms\n",
"Median on partition took: 23 ms\n",
"Median on partition took: 23 ms\n",
"Median on partition took: 24 ms\n",
"Median on partition took: 22 ms\n",
"Median on partition took: 23 ms\n",
"Median on partition took: 24 ms\n"
"Median on partition took: 21 ms\n",
"Median on partition took: 20 ms\n",
"Median on partition took: 20 ms\n",
"Median on partition took: 20 ms\n"
]
}
],
@@ -175,18 +176,11 @@
 " with Timer(text=\"Median on partition took: {milliseconds:.0f} ms\"):\n",
 " partition.select(col(numeric_col).median()).collect()"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
 "kernelspec": {
-"display_name": ".env",
+"display_name": ".venv",
 "language": "python",
 "name": "python3"
 },
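For reference, the digest-based estimate printed above (1.7201…) lands close to the exact median the later cells compute (1.71). A sanity-check sketch along those lines (column and file names follow the notebook; the glob pattern and the 5% tolerance are assumptions):

import polars as pl

# Exact median over both months, computed the slow way for comparison.
exact = (
    pl.scan_parquet("yellow_tripdata_2024-0*.parquet")  # glob assumed
    .select(pl.col("trip_distance").median())
    .collect()
    .item()
)

estimated = 1.7201926614756424  # value printed by the run above
assert abs(estimated - exact) / exact < 0.05  # loose, assumed tolerance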
