Skip to content

Commit

Permalink
added function to merge tdigest with small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
thobai committed Nov 18, 2024
1 parent 43f3bca commit 41a8668
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
10 changes: 10 additions & 0 deletions polars_tdigest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,13 @@ def tdigest_cast(expr: IntoExpr) -> pl.Expr:
is_elementwise=False,
returns_scalar=True,
)


def merge_tdigests(expr: IntoExpr) -> pl.Expr:
return register_plugin_function(
plugin_path=Path(__file__).parent,
function_name="merge_tdigests",
args=expr,
is_elementwise=False,
returns_scalar=True,
)
35 changes: 30 additions & 5 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ struct MergeTDKwargs {
quantile: f64,
}

// TODO this should check the type of the series and also work on series of Type f64
#[polars_expr(output_type=Float64)]
fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult<Series> {
fn extract_tdigest_vec(inputs: &[Series]) -> Vec<TDigestCol> {
let mut df = inputs[0].clone().into_frame();
df.set_column_names(vec!["tdigest"].as_slice()).unwrap();
let mut buf = BufWriter::new(Vec::new());
Expand All @@ -262,8 +260,35 @@ fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult<S

let bytes = buf.into_inner().unwrap();
let json_str = String::from_utf8(bytes).unwrap();
let tdigest_json: Vec<TDigestCol> =
serde_json::from_str(&json_str).expect("Failed to parse the tdigest JSON string");

serde_json::from_str(&json_str).expect("Failed to parse the tdigest JSON string")
}

#[polars_expr(output_type_func=tdigest_output)]
fn merge_tdigests(inputs: &[Series]) -> PolarsResult<Series> {
let series = &inputs[0];
let tdigest_json: Vec<TDigestCol> = extract_tdigest_vec(inputs);

let tdigests: Vec<TDigest> = tdigest_json.into_iter().map(|td| td.tdigest).collect();
let tdigest = TDigest::merge_digests(tdigests);

let td_json = serde_json::to_string(&tdigest).unwrap();

let file = Cursor::new(&td_json);
let df = JsonReader::new(file)
.with_json_format(JsonFormat::JsonLines)
.infer_schema_len(Some(3))
.with_batch_size(NonZeroUsize::new(3).unwrap())
.finish()
.unwrap();

Ok(df.into_struct(series.name()).into_series())
}

// TODO this should check the type of the series and also work on series of Type f64
#[polars_expr(output_type=Float64)]
fn estimate_quantile(inputs: &[Series], kwargs: MergeTDKwargs) -> PolarsResult<Series> {
let tdigest_json: Vec<TDigestCol> = extract_tdigest_vec(inputs);

let tdigests: Vec<TDigest> = tdigest_json.into_iter().map(|td| td.tdigest).collect();
let tdigest = TDigest::merge_digests(tdigests);
Expand Down

0 comments on commit 41a8668

Please sign in to comment.