diff --git a/Cargo.lock b/Cargo.lock index 654dedd..8a9af71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,15 +57,16 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "6.5.0" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "216c6846a292bdd93c2b93c1baab58c32ff50e2ab5e8d50db333ab518535dd8b" +checksum = "ce240772a007c63658c1d335bb424fd1019b87895dee899b7bf70e85b2d24e5f" dependencies = [ "bitflags", "chrono", "comfy-table", "csv", "flatbuffers", + "half", "hex", "indexmap", "lazy_static", @@ -111,13 +112,11 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "blake2" -version = "0.9.2" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" +checksum = "b94ba84325db59637ffc528bbe8c7f86c02c57cff5c0e2b9b00f9a851f42f309" dependencies = [ - "crypto-mac", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -131,14 +130,14 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "digest", + "digest 0.9.0", ] [[package]] name = "block-buffer" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +checksum = "03588e54c62ae6d763e2a80090d50353b785795361b4ff5b3bf0a5097fc31c0b" dependencies = [ "generic-array", ] @@ -206,7 +205,6 @@ dependencies = [ "libc", "num-integer", "num-traits", - "time", "winapi", ] @@ -246,13 +244,12 @@ dependencies = [ ] [[package]] -name = "crypto-mac" -version = "0.8.0" +name = "crypto-common" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" +checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0" dependencies = [ "generic-array", - "subtle", ] [[package]] @@ -280,8 +277,7 @@ dependencies = [ [[package]] name = "datafusion" version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e4a8a1f1ee057b2c27a01f050b9dffe56e8d43605d0201234b353a3cc1eb2f" +source = "git+https://github.com/apache/arrow-datafusion.git#15cfcbc28305e82891a5a52d252fb23c72fd8458" dependencies = [ "ahash", "arrow", @@ -290,12 +286,13 @@ dependencies = [ "blake3", "chrono", "futures", - "hashbrown", + "hashbrown 0.12.0", "lazy_static", "log", "md-5", "num_cpus", "ordered-float 2.10.0", + "parking_lot 0.12.0", "parquet", "paste 1.0.6", "pin-project-lite", @@ -305,6 +302,7 @@ dependencies = [ "sha2", "smallvec", "sqlparser", + "tempfile", "tokio", "tokio-stream", "unicode-segmentation", @@ -314,7 +312,9 @@ dependencies = [ name = "datafusion-python" version = "0.4.0" dependencies = [ + "async-trait", "datafusion", + "futures", "pyo3", "rand 0.7.3", "tokio", @@ -330,6 +330,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "digest" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b" +dependencies = [ + "block-buffer", + "crypto-common", + "generic-array", + "subtle", +] + +[[package]] +name = "fastrand" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" +dependencies = [ + "instant", +] + [[package]] 
name = "flatbuffers" version = "2.0.0" @@ -474,11 +495,23 @@ dependencies = [ "wasi 0.10.2+wasi-snapshot-preview1", ] +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hashbrown" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" + +[[package]] +name = "hashbrown" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c21d40587b92fa6a6c6e3c1bdbf87d75511db5672f9c93175574b3a00df1758" dependencies = [ "ahash", ] @@ -514,7 +547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.11.2", ] [[package]] @@ -654,9 +687,9 @@ checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125" [[package]] name = "lock_api" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712a4d093c9976e24e7dbca41db895dabcbac38eb5f4045393d17a95bdfb1109" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -692,13 +725,11 @@ dependencies = [ [[package]] name = "md-5" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" +checksum = "e6a38fc55c8bbc10058782919516f88826e70320db6d206aebc49611d24216ae" dependencies = [ - "block-buffer", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -829,12 +860,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da32515d9f6e6e489d7bc9d84c71b060db7247dc035bbe44eac88cf87486d8d5" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "ordered-float" version = "1.1.1" @@ -861,7 +886,17 @@ checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.5", +] + +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.0", ] [[package]] @@ -878,11 +913,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot_core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2f4f894f3865f6c0e02810fc597300f34dc2510f66400da262d8ae10e75767d" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "parquet" -version = "6.5.0" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "788d9953f4cfbe9db1beff7bebd54299d105e34680d78b82b1ddc85d432cac9d" +checksum = "2d5a6492e0b849fd458bc9364aee4c8a9882b3cc21b2576767162725f69d2ad8" dependencies = [ "arrow", "base64", @@ -901,9 +949,9 @@ dependencies = [ [[package]] name = "parquet-format" -version = "2.6.1" +version = "4.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5bc6b23543b5dedc8f6cce50758a35e5582e148e0cfa26bd0cacd569cda5b71" +checksum = "1f0c06cdcd5460967c485f9c40a821746f5955ad81990533c7fae95dbd9bc0b5" dependencies = [ "thrift", ] @@ -968,14 +1016,14 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35100f9347670a566a67aa623369293703322bb9db77d99d7df7313b575ae0c8" +checksum = "7cf01dbf1c05af0a14c7779ed6f3aa9deac9c3419606ac9de537a2d649005720" dependencies = [ "cfg-if", "indoc", "libc", - "parking_lot", + "parking_lot 0.11.2", "paste 0.1.18", "pyo3-build-config", "pyo3-macros", @@ -984,18 +1032,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d12961738cacbd7f91b7c43bc25cfeeaa2698ad07a04b3be0aa88b950865738f" +checksum = "dbf9e4d128bfbddc898ad3409900080d8d5095c379632fbbfbb9c8cfb1fb852b" dependencies = [ "once_cell", ] [[package]] name = "pyo3-macros" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0bc5215d704824dfddddc03f93cb572e1155c68b6761c37005e1c288808ea8" +checksum = "67701eb32b1f9a9722b4bc54b548ff9d7ebfded011c12daece7b9063be1fd755" dependencies = [ "pyo3-macros-backend", "quote", @@ -1004,9 +1052,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71623fc593224afaab918aa3afcaf86ed2f43d34f6afde7f3922608f253240df" +checksum = "f44f09e825ee49a105f2c7b23ebee50886a9aee0746f4dd5a704138a64b0218a" dependencies = [ "proc-macro2", "pyo3-build-config", @@ -1136,6 +1184,15 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] + [[package]] name = "ryu" version = "1.0.9" @@ -1179,15 +1236,13 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.8" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" +checksum = "99c3bd8169c58782adad9290a9af5939994036b76187f7b4f0e6de91dbbfc0ec" dependencies = [ - "block-buffer", "cfg-if", "cpufeatures", - "digest", - "opaque-debug", + "digest 0.10.1", ] [[package]] @@ -1210,9 +1265,9 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" [[package]] name = "sqlparser" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760e624412a15d5838ae04fad01037beeff1047781431d74360cddd6b3c1c784" +checksum = "b9907f54bd0f7b6ce72c2be1e570a614819ee08e3deb66d90480df341d8a12a8" dependencies = [ "log", ] @@ -1258,6 +1313,20 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "tempfile" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +dependencies = [ + "cfg-if", + "fastrand", + "libc", + "redox_syscall", + "remove_dir_all", + "winapi", +] + [[package]] name = "thiserror" 
version = "1.0.30" @@ -1300,16 +1369,6 @@ dependencies = [ "threadpool", ] -[[package]] -name = "time" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "tokio" version = "1.15.0" @@ -1317,6 +1376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbbf1c778ec206785635ce8ad57fe52b3009ae9e0c9f574a728f3049d3e55838" dependencies = [ "num_cpus", + "parking_lot 0.11.2", "pin-project-lite", "tokio-macros", ] @@ -1422,6 +1482,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ceb069ac8b2117d36924190469735767f0990833935ab430155e71a44bafe148" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d027175d00b01e0cbeb97d6ab6ebe03b12330a35786cbaca5252b1c4bf5d9b" + +[[package]] +name = "windows_i686_gnu" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8793f59f7b8e8b01eda1a652b2697d87b93097198ae85f823b969ca5b89bba58" + +[[package]] +name = "windows_i686_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8602f6c418b67024be2996c512f5f995de3ba417f4c75af68401ab8756796ae4" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3d615f419543e0bd7d2b3323af0d86ff19cbc4f816e6453f36a2c2ce889c354" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d95421d9ed3672c280884da53201a5c46b7b2765ca6faf34b0d71cf34a3561" + [[package]] name = "zstd" version = "0.9.1+zstd.1.5.1" diff --git a/Cargo.toml b/Cargo.toml index aa16236..b80d124 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,11 +28,25 @@ edition = "2021" rust-version = "1.57" [dependencies] -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", +] } rand = "0.7" -pyo3 = { version = "0.14", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { version = "6.0.0", features = ["pyarrow"] } +pyo3 = { version = "0.15", features = [ + "extension-module", + "abi3", + "abi3-py36", +] } +# datafusion = { version = "6.0.0", features = ["pyarrow"] } +datafusion = { git = "https://github.com/apache/arrow-datafusion.git", features = [ + "pyarrow", +] } uuid = { version = "0.8", features = ["v4"] } +async-trait = "0.1.41" +futures = "0.3" [lib] name = "_internal" diff --git a/datafusion/tests/test_pyarrow_dataset.py b/datafusion/tests/test_pyarrow_dataset.py new file mode 100644 index 0000000..5aafabc --- /dev/null +++ b/datafusion/tests/test_pyarrow_dataset.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import date, timedelta
+from tempfile import mkdtemp
+
+import pyarrow as pa
+import pyarrow.dataset as ds
+import pytest
+
+from datafusion import ExecutionContext
+
+
+@pytest.fixture
+def ctx():
+    return ExecutionContext()
+
+
+@pytest.fixture
+def table():
+    table = pa.table({
+        'z': pa.array([x / 3 for x in range(8)]),
+        'x': pa.array(['a'] * 3 + ['b'] * 5),
+        'y': pa.array([date(2020, 1, 1) + timedelta(days=x) for x in range(8)]),
+    })
+    return table
+
+
+@pytest.fixture
+def dataset(ctx, table):
+    tmp_dir = mkdtemp()
+
+    part = ds.partitioning(
+        pa.schema([('x', pa.string()), ('y', pa.date32())]),
+        flavor="hive",
+    )
+
+    ds.write_dataset(table, tmp_dir, partitioning=part, format="parquet")
+
+    dataset = ds.dataset(tmp_dir, partitioning=part)
+    ctx.register_dataset("ds", dataset)
+    return dataset
+
+
+def test_catalog(ctx, table, dataset):
+    catalog_table = ctx.catalog().database().table("ds")
+    assert catalog_table.kind == "physical"
+    assert catalog_table.schema == table.schema
+
+
+def test_scan_full(ctx, table, dataset):
+    result = ctx.sql("SELECT * FROM ds").collect()
+    assert pa.Table.from_batches(result) == table
+
+
+def test_dataset_filter(ctx: ExecutionContext, table: pa.Table, dataset):
+    result = ctx.sql(
+        "SELECT * FROM ds WHERE y BETWEEN '2020-01-02' AND '2020-01-06' AND x = 'b'"
+    ).collect()
+    assert sum(batch.num_rows for batch in result) == 3
+
+
+def test_dataset_project(ctx: ExecutionContext, table: pa.Table, dataset):
+    result = ctx.sql("SELECT z, y FROM ds").collect()
+    assert pa.Table.from_batches(result).column_names == ['z', 'y']
diff --git a/src/context.rs b/src/context.rs
index 7f386ba..0672d35 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -31,6 +31,7 @@ use datafusion::prelude::CsvReadOptions;
 
 use crate::catalog::PyCatalog;
 use crate::dataframe::PyDataFrame;
+use crate::dataset::PyArrowDatasetTable;
 use crate::errors::DataFusionError;
 use crate::udf::PyScalarUDF;
 use crate::utils::wait_for_future;
@@ -60,10 +61,7 @@ impl PyExecutionContext {
         Ok(PyDataFrame::new(df))
     }
 
-    fn create_dataframe(
-        &mut self,
-        partitions: Vec<Vec<RecordBatch>>,
-    ) -> PyResult<PyDataFrame> {
+    fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
         let table = MemTable::try_new(partitions[0][0].schema(), partitions)
             .map_err(DataFusionError::from)?;
 
@@ -143,6 +141,13 @@ impl PyExecutionContext {
         Ok(())
     }
 
+    fn register_dataset(&mut self, name: &str, dataset: PyArrowDatasetTable) -> PyResult<()> {
+        self.ctx
+            .register_table(name, Arc::new(dataset))
+            .map_err(DataFusionError::from)?;
+        Ok(())
+    }
+
     fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
         self.ctx.register_udf(udf.function);
         Ok(())
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 9050df9..171c53c 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -19,8 +19,10 @@ use std::sync::Arc;
 
 use pyo3::prelude::*;
 
+use datafusion::arrow::array::StringArray;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
 use datafusion::logical_plan::JoinType;
@@ -100,6 +102,33 @@ impl PyDataFrame {
         Ok(pretty::print_batches(&batches)?)
     }
 
+    #[args(verbose = false, analyze = false)]
+    fn explain(&self, verbose: bool, analyze: bool, py: Python) -> PyResult<()> {
+        let df = self.df.explain(verbose, analyze)?;
+        let batches = wait_for_future(py, df.collect())?;
+        let batch = RecordBatch::concat(&batches[0].schema(), &batches)?;
+
+        let plan_types = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("Plan types is not a String anymore");
+        let plans = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("Plan is not a String anymore");
+
+        for (plan_type, plan) in plan_types.iter().zip(plans.iter()) {
+            if let (Some(plan_type), Some(plan)) = (plan_type, plan) {
+                println!("{}", plan_type);
+                println!("{}", plan);
+            }
+        }
+
+        Ok(())
+    }
+
     fn join(
         &self,
         right: PyDataFrame,
diff --git a/src/dataset.rs b/src/dataset.rs
new file mode 100644
index 0000000..b7935ce
--- /dev/null
+++ b/src/dataset.rs
@@ -0,0 +1,486 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
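+
+//! `PyArrowDatasetTable` exposes a `pyarrow.dataset.Dataset` to DataFusion as
+//! a `TableProvider`: filters that convert cleanly to pyarrow expressions
+//! (via `expr_to_pyarrow`) are pushed down to the pyarrow scanner as inexact
+//! pushdowns, and record batches are produced on a blocking tokio task that
+//! feeds an mpsc channel.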
+
+use async_trait::async_trait;
+use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
+use datafusion::arrow::pyarrow::PyArrowConvert;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::datasource::TableProviderFilterPushDown;
+use datafusion::datasource::TableProvider;
+use datafusion::error::DataFusionError;
+use datafusion::error::Result;
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_plan::{Expr, ExpressionVisitor, Operator, Recursion};
+use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::stream::RecordBatchReceiverStream;
+use datafusion::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
+};
+use pyo3::exceptions::{PyAssertionError, PyNotImplementedError, PyStopIteration};
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+use tokio::{
+    sync::mpsc::{channel, Receiver, Sender},
+    task,
+};
+
+pub struct PyArrowDatasetTable {
+    dataset: Py<PyAny>,
+    schema: SchemaRef,
+}
+
+impl PyArrowDatasetTable {
+    /// Returns true if the expression can be evaluated by pyarrow against this dataset
+    fn expression_valid(&self, expr: &Expr) -> bool {
+        if let Ok(pyarrow_expr) = expr_to_pyarrow(expr) {
+            let res = Python::with_gil(|py| -> PyResult<()> {
+                let scanner_kwargs = PyDict::new(py);
+                scanner_kwargs.set_item("filter", pyarrow_expr)?;
+                self.dataset
+                    .call_method(py, "scanner", (), Some(scanner_kwargs))?;
+                Ok(())
+            });
+            res.is_ok()
+        } else {
+            false
+        }
+    }
+}
+
+impl<'py> FromPyObject<'py> for PyArrowDatasetTable {
+    fn extract(ob: &'py PyAny) -> PyResult<Self> {
+        // TODO: verify that ob is a PyArrow dataset
+        // (e.g. "pyarrow.dataset.FileSystemDataset")
+        let dataset: Py<PyAny> = ob.extract()?;
+        let schema = Python::with_gil(|py| -> PyResult<Schema> {
+            Schema::from_pyarrow(dataset.getattr(py, "schema")?.as_ref(py))
+        })?;
+
+        Ok(PyArrowDatasetTable {
+            dataset,
+            schema: Arc::new(schema),
+        })
+    }
+}
+
+#[async_trait]
+impl TableProvider for PyArrowDatasetTable {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    async fn scan(
+        &self,
+        projection: &Option<Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Filtering is only inexact because of expression conversion, but the
+        // PyArrow scanner does apply all filters given to it.
+        let combined_filter = filters
+            .iter()
+            .filter(|expr| self.expression_valid(expr))
+            .cloned()
+            .reduce(|acc, item| acc.and(item));
+        let scanner = PyArrowDatasetScanner::make(
+            self.dataset.clone(),
+            self.schema.clone(),
+            projection,
+            combined_filter.clone(),
+            limit,
+            10, // Dummy value; scanner recreated later with runtime batch_size.
+        );
+
+        match scanner {
+            Ok(scanner) => Ok(Arc::new(PyArrowDatasetExec {
+                dataset: self.dataset.clone(),
+                scanner,
+                projection: projection.clone(),
+                filter: combined_filter,
+                limit,
+                schema: self.schema.clone(),
+                projected_statistics: Statistics::default(),
+                metrics: ExecutionPlanMetricsSet::new(),
+            })),
+            Err(err) => Err(DataFusionError::Execution(err.to_string())),
+        }
+    }
+
+    fn supports_filter_pushdown(&self, _: &Expr) -> Result<TableProviderFilterPushDown> {
+        Ok(TableProviderFilterPushDown::Inexact)
+    }
+}
+
+pub struct PyArrowDatasetScanner {
+    scanner: Arc<Py<PyAny>>,
+    limit: Option<usize>,
+    schema: SchemaRef,
+}
+
+impl fmt::Debug for PyArrowDatasetScanner {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PyArrowDatasetScanner")
+            .field("scanner", &"pyarrow.dataset.Scanner")
+            .field("limit", &self.limit)
+            .field("schema", &self.schema)
+            .finish()
+    }
+}
+
+impl Clone for PyArrowDatasetScanner {
+    fn clone(&self) -> Self {
+        PyArrowDatasetScanner {
+            // TODO: Is this a bad way to clone?
+            scanner: self.scanner.clone(),
+            limit: self.limit,
+            schema: self.schema.clone(),
+        }
+    }
+}
+
+impl PyArrowDatasetScanner {
+    fn make(
+        dataset: Py<PyAny>,
+        schema: SchemaRef,
+        projection: &Option<Vec<usize>>,
+        filter: Option<Expr>,
+        limit: Option<usize>,
+        batch_size: usize,
+    ) -> Result<Self> {
+        let scanner = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
+            let scanner_kwargs = PyDict::new(py);
+            scanner_kwargs.set_item("batch_size", batch_size)?;
+            if let Some(expr) = filter {
+                scanner_kwargs.set_item("filter", expr_to_pyarrow(&expr)?)?;
+            };
+
+            if let Some(indices) = projection {
+                let column_names: Vec<String> = schema
+                    .project(indices)?
+                    .fields()
+                    .iter()
+                    .map(|field| field.name().clone())
+                    .collect();
+                scanner_kwargs.set_item("columns", column_names)?;
+            }
+
+            dataset
+                .call_method(py, "scanner", (), Some(scanner_kwargs))?
+                .extract(py)
+        });
+        match scanner {
+            Ok(scanner) => Ok(Self {
+                scanner: Arc::new(scanner),
+                limit,
+                schema,
+            }),
+            Err(err) => Err(DataFusionError::Execution(err.to_string())),
+        }
+    }
+
+    fn projected_schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn get_batches(&self, response_tx: Sender<ArrowResult<RecordBatch>>) -> Result<()> {
+        let mut count = 0;
+
+        // TODO: Avoid Python GIL with Arrow C Stream interface?
+        // https://arrow.apache.org/docs/dev/format/CStreamInterface.html
+        // https://github.com/apache/arrow/blob/cc4e2a54309813e6bbbb36ba50bcd22a7b71d3d9/python/pyarrow/ipc.pxi#L620
+        let batch_iter = Python::with_gil(|py| self.scanner.call_method0(py, "to_batches"))
+            .map_err(|err| DataFusionError::Execution(err.to_string()))?;
+
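+        // Drive the Python iterator to completion: each `__next__` call
+        // re-acquires the GIL and converts one pyarrow batch into a Rust
+        // RecordBatch; a StopIteration from Python marks the end of the
+        // stream, and any `limit` is enforced by slicing the final batch.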
+        loop {
+            // TODO: Avoid Python GIL with Arrow C Stream interface?
+            // https://arrow.apache.org/docs/dev/format/CStreamInterface.html
+            // https://github.com/apache/arrow/blob/cc4e2a54309813e6bbbb36ba50bcd22a7b71d3d9/python/pyarrow/ipc.pxi#L620
+            let res = Python::with_gil(|py| -> PyResult<Option<RecordBatch>> {
+                let py_batch_res = batch_iter.call_method0(py, "__next__");
+                match py_batch_res {
+                    Ok(py_batch) => Ok(Some(RecordBatch::from_pyarrow(py_batch.extract(py)?)?)),
+                    Err(error) if error.is_instance::<PyStopIteration>(py) => Ok(None),
+                    Err(error) => Err(error),
+                }
+            });
+
+            match (self.limit, res) {
+                (Some(limit), Ok(Some(batch))) => {
+                    // Handle limit parameter by stopping iterator early
+                    let next_total = count + batch.num_rows();
+                    if next_total == limit {
+                        send_result(&response_tx, Ok(batch))?;
+                        break;
+                    } else if next_total < limit {
+                        count += batch.num_rows();
+                        send_result(&response_tx, Ok(batch))?;
+                    } else {
+                        // Truncate the final batch to the remaining row count.
+                        send_result(&response_tx, Ok(batch.slice(0, limit - count)))?;
+                        count = limit;
+                        break;
+                    }
+                }
+                (None, Ok(Some(batch))) => {
+                    count += batch.num_rows();
+                    send_result(&response_tx, Ok(batch))?;
+                }
+                (_, Ok(None)) => {
+                    break;
+                }
+                (_, Err(err)) => {
+                    send_result(&response_tx, Err(ArrowError::IoError(err.to_string())))?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+fn send_result(
+    response_tx: &Sender<ArrowResult<RecordBatch>>,
+    result: ArrowResult<RecordBatch>,
+) -> Result<()> {
+    // Note: this function runs on its own blocking tokio thread, so blocking here is ok.
+    response_tx
+        .blocking_send(result)
+        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+    Ok(())
+}
+
+/// Execution plan for scanning a PyArrow dataset
+#[derive(Debug, Clone)]
+pub struct PyArrowDatasetExec {
+    dataset: Py<PyAny>,
+    scanner: PyArrowDatasetScanner,
+    projection: Option<Vec<usize>>,
+    filter: Option<Expr>,
+    limit: Option<usize>,
+    schema: SchemaRef,
+    projected_statistics: Statistics,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+}
+
+#[async_trait]
+impl ExecutionPlan for PyArrowDatasetExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.scanner.projected_schema()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    /// Get the output partitioning of this plan
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(1)
+    }
+
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.is_empty() {
+            Ok(Arc::new(self.clone()))
+        } else {
+            Err(DataFusionError::Internal(format!(
+                "Children cannot be replaced in {:?}",
+                self
+            )))
+        }
+    }
+
+    async fn execute(
+        &self,
+        _partition_index: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Result<SendableRecordBatchStream> {
+        let (response_tx, response_rx): (
+            Sender<ArrowResult<RecordBatch>>,
+            Receiver<ArrowResult<RecordBatch>>,
+        ) = channel(2);
+
+        // Have to recreate the scanner with the correct runtime batch size
+        let scanner = PyArrowDatasetScanner::make(
+            self.dataset.clone(),
+            self.schema.clone(),
+            &self.projection,
+            self.filter.clone(),
+            self.limit,
+            runtime.batch_size,
+        )?;
+
+        let join_handle = task::spawn_blocking(move || {
+            if let Err(e) = scanner.get_batches(response_tx) {
+                println!("Dataset scanner thread terminated due to error: {:?}", e);
+            }
+        });
+
+        Ok(RecordBatchReceiverStream::create(
+            &self.scanner.projected_schema(),
+            response_rx,
+            join_handle,
+        ))
+    }
+
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "PyArrowDatasetExec: filter={:?}, limit={:?}, projection={:?} partitions=...",
+                    self.filter, self.scanner.limit, self.projection
+                )
+            }
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        self.projected_statistics.clone()
+    }
+}
+
+struct PyArrowExprVisitor {
+    result_stack: Vec<PyObject>,
+}
+
+impl ExpressionVisitor for PyArrowExprVisitor {
+    fn pre_visit(self, _expr: &Expr) -> Result<Recursion<Self>> {
+        Ok(Recursion::Continue(self))
+    }
+
+    fn post_visit(mut self, expr: &Expr) -> Result<Self> {
+        let res = Python::with_gil(|py| -> PyResult<()> {
+            let ds = PyModule::import(py, "pyarrow.dataset")?;
+            let field = ds.getattr("field")?;
+
+            match expr {
+                Expr::Column(col) => {
+                    self.result_stack
+                        .push(field.call1((col.name.clone(),))?.into());
+                }
+                Expr::Literal(scalar) => {
+                    self.result_stack.push(scalar.to_pyarrow(py)?);
+                }
+                Expr::BinaryExpr {
+                    left: _,
+                    right: _,
+                    op,
+                } => {
+                    // Must be popped in reverse order of visitation
+                    let right_val = self.result_stack.pop().unwrap();
+                    let left_val = self.result_stack.pop().unwrap();
+
+                    let method = match op {
+                        Operator::Eq => Ok("__eq__"),
+                        Operator::NotEq => Ok("__ne__"),
+                        Operator::Lt => Ok("__lt__"),
+                        Operator::LtEq => Ok("__le__"),
+                        Operator::Gt => Ok("__gt__"),
+                        Operator::GtEq => Ok("__ge__"),
+                        Operator::Plus => Ok("__add__"),
+                        Operator::Minus => Ok("__sub__"),
+                        Operator::Multiply => Ok("__mul__"),
+                        Operator::Divide => Ok("__truediv__"),
+                        Operator::Modulo => Ok("__mod__"),
+                        Operator::Or => Ok("__or__"),
+                        Operator::And => Ok("__and__"),
+                        _ => Err(PyNotImplementedError::new_err(
+                            "Operation not yet supported",
+                        )),
+                    };
+
+                    self.result_stack
+                        .push(left_val.call_method1(py, method?, (right_val,))?);
+                }
+                Expr::Not(_) => {
+                    let val = self.result_stack.pop().unwrap();
+
+                    self.result_stack.push(val.call_method0(py, "__invert__")?);
+                }
+                Expr::Between {
+                    expr: _,
+                    negated,
+                    low: _,
+                    high: _,
+                } => {
+                    // Must be popped in reverse order of visitation
+                    let high_val = self.result_stack.pop().unwrap();
+                    let low_val = self.result_stack.pop().unwrap();
+                    let expr_val = self.result_stack.pop().unwrap();
+
+                    let gte_val = expr_val.call_method1(py, "__ge__", (low_val,))?;
+                    let lte_val = expr_val.call_method1(py, "__le__", (high_val,))?;
+                    let mut val = gte_val.call_method1(py, "__and__", (lte_val,))?;
+                    if *negated {
+                        val = val.call_method0(py, "__invert__")?;
+                    }
+                    self.result_stack.push(val);
+                }
+                _ => {
+                    return Err(PyNotImplementedError::new_err(
+                        "Expression not yet supported",
+                    ));
+                }
+            }
+            Ok(())
+        });
+
+        match res {
+            Ok(_) => Ok(self),
+            Err(err) => Err(DataFusionError::External(Box::new(err))),
+        }
+    }
+}
+
+// TODO: replace with some Substrait conversion?
+// https://github.com/apache/arrow-rs/blob/master/arrow/src/pyarrow.rs
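+//
+// Example: the DataFusion predicate `x = 'b' AND y >= 2` comes out of the
+// visitor above (built bottom-up via the result stack) as the pyarrow
+// expression `(ds.field("x") == "b") & (ds.field("y") >= 2)`.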
+fn expr_to_pyarrow(expr: &Expr) -> PyResult<PyObject> {
+    Python::with_gil(|py| -> PyResult<PyObject> {
+        let visitor = PyArrowExprVisitor {
+            result_stack: Vec::new(),
+        };
+
+        let mut final_visitor = expr.accept(visitor)?;
+
+        match final_visitor.result_stack.len() {
+            1 => Ok(final_visitor.result_stack.pop().unwrap()),
+            _ => Err(PyAssertionError::new_err(
+                "expression conversion did not produce exactly one result",
+            )),
+        }
+    })
+}
diff --git a/src/lib.rs b/src/lib.rs
index d40bae2..f9ba393 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,6 +20,7 @@ use pyo3::prelude::*;
 mod catalog;
 mod context;
 mod dataframe;
+mod dataset;
 mod errors;
 mod expression;
 mod functions;
diff --git a/src/udaf.rs b/src/udaf.rs
index 1de6e63..c25fd1f 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -49,16 +49,6 @@ impl Accumulator for RustAccumulator {
             .map_err(|e| DataFusionError::Execution(format!("{}", e)))
     }
 
-    fn update(&mut self, _values: &[ScalarValue]) -> Result<()> {
-        // no need to implement as datafusion does not use it
-        todo!()
-    }
-
-    fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> {
-        // no need to implement as datafusion does not use it
-        todo!()
-    }
-
     fn evaluate(&self) -> Result<ScalarValue> {
         Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract())
             .map_err(|e| DataFusionError::Execution(format!("{}", e)))
@@ -144,7 +134,6 @@ impl PyAggregateUDF {
     }
 
     /// creates a new PyExpr with the call of the udf
-    #[call]
     #[args(args = "*")]
     fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
         let args = args.iter().map(|e| e.expr.clone()).collect();
diff --git a/src/udf.rs b/src/udf.rs
index 379c449..8251739 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -24,9 +24,7 @@ use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::error::DataFusionError;
 use datafusion::logical_plan;
-use datafusion::physical_plan::functions::{
-    make_scalar_function, ScalarFunctionImplementation,
-};
+use datafusion::physical_plan::functions::{make_scalar_function, ScalarFunctionImplementation};
 use datafusion::physical_plan::udf::ScalarUDF;
 
 use crate::expression::PyExpr;
@@ -89,7 +87,6 @@ impl PyScalarUDF {
     }
 
     /// creates a new PyExpr with the call of the udf
-    #[call]
     #[args(args = "*")]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        let args = args.iter().map(|e| e.expr.clone()).collect();
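
Usage sketch (not part of the patch; assumes a local parquet directory, and
the path and table names are illustrative only): registering a pyarrow
dataset and the pyarrow-side scan that PyArrowDatasetScanner effectively
constructs for a pushed-down query.

    import pyarrow.dataset as ds
    from datafusion import ExecutionContext

    ctx = ExecutionContext()
    dataset = ds.dataset("my_data/", format="parquet")  # hypothetical path
    ctx.register_dataset("my_table", dataset)
    batches = ctx.sql("SELECT z FROM my_table WHERE x = 'b'").collect()

    # Roughly equivalent pyarrow-side scan built in src/dataset.rs:
    scanner = dataset.scanner(
        batch_size=8192,              # runtime.batch_size in execute()
        columns=["z"],                # projection pushdown
        filter=ds.field("x") == "b",  # converted via expr_to_pyarrow
    )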