Execute substrait plans #1041

Open
wants to merge 48 commits into main

Commits (48)
e142cd2
Upgrade to DataFusion 16.0.0
andygrove Jan 18, 2023
0d038b2
fix
andygrove Jan 18, 2023
f03818d
fix todo
andygrove Jan 18, 2023
c5f06b4
Merge branch 'main' of https://github.com/dask-contrib/dask-sql into …
ayushdg Jan 20, 2023
a0fbcbf
Fix clippy style errors
ayushdg Jan 20, 2023
711efb3
Update getOffset and isUnbounded to better handle ScalarValue windowB…
ayushdg Jan 20, 2023
3e97e09
Window Plan: Fix field name order handling since the order now matche…
ayushdg Jan 20, 2023
06c4951
Fix comment location
ayushdg Jan 20, 2023
b4e8b0a
Predicate Filters: Df16 reverts changes in df15 and returns a simpler…
ayushdg Jan 20, 2023
c3c526b
testing
jdye64 Jan 24, 2023
eced31c
Merge remote-tracking branch 'origin/main' into datafusion-16
charlesbluca Jan 24, 2023
68cf931
bump datafusion -> 16.1.0
jdye64 Jan 24, 2023
64f6338
merge update
jdye64 Jan 24, 2023
8e01ac5
Fix 2 failing join pytests
jdye64 Jan 25, 2023
b6e3146
cargo fmt --all
jdye64 Jan 25, 2023
9c2923e
Ok really fix cargo fmt this time
jdye64 Jan 25, 2023
3e7acc8
Update join.rs to accept CAST as a JOIN condition type
jdye64 Jan 25, 2023
dc57fb9
Merge remote-tracking branch 'origin/main' into datafusion-16
charlesbluca Jan 25, 2023
61857bc
Add explicit handling for RexAlias
charlesbluca Jan 25, 2023
345449d
merge with upstream/main
jdye64 Jan 26, 2023
1cf9b5d
Merge branch 'datafusion-16' of github.com:andygrove/dask-sql into da…
jdye64 Jan 26, 2023
2af8d66
Bump DataFusion -> 17.0.0
jdye64 Jan 30, 2023
66734fa
first pass at bindings for substrait
jdye64 Jan 31, 2023
2e44f0a
test
jdye64 Feb 1, 2023
6875aa3
clippy checks
jdye64 Feb 9, 2023
54d27ce
merge with upstream/main
jdye64 Feb 9, 2023
975b592
Merge remote-tracking branch 'origin/main' into datafusion-16
charlesbluca Feb 9, 2023
1c61577
Fix clippy warnings
charlesbluca Feb 9, 2023
cfc3c21
Add protoc action to workflows for rust and style
jdye64 Feb 9, 2023
62eee30
use arrow re-exported from datafusion
jdye64 Feb 9, 2023
2987e7c
Add protoc action tot conda.yml
jdye64 Feb 9, 2023
d1949fc
Add protobuf install to import / conda build testing
charlesbluca Feb 9, 2023
bd92280
Enable RUST_BACKTRACE=1
jdye64 Feb 9, 2023
955e53b
Add RUST_BACKTRACE=1 to meta.yaml
charlesbluca Feb 9, 2023
5718d6d
Add Rust Toolchain action to conda.yml
jdye64 Feb 9, 2023
4877e5e
Merge branch 'datafusion-16' of github.com:andygrove/dask-sql into da…
jdye64 Feb 9, 2023
f224ab0
Bump setuptools-rust version to 1.5.2 in hopes that helps
jdye64 Feb 9, 2023
ea5f5b8
Enable RUST_BACKTRACE=full
jdye64 Feb 10, 2023
2fae72a
Remove substrait package for now
jdye64 Feb 10, 2023
e324a79
Updates for BinaryExpr in join.rs and comment out FilterColumnsPostJo…
jdye64 Feb 12, 2023
7503e38
Remove protoc github action
jdye64 Feb 12, 2023
41ab6d1
Disable pytest that was using a disabled optimizer rule
jdye64 Feb 12, 2023
a4418fd
merge with upstream
jdye64 Feb 13, 2023
66c1f4a
Add flag to context.sql for allowing substrait plans to be ran
jdye64 Feb 13, 2023
00762e0
merge upstream/main
jdye64 Feb 13, 2023
28d8dc3
remove protobuf
jdye64 Feb 13, 2023
d7508a8
remove rust toolchain that isn't needed
jdye64 Feb 13, 2023
2a69c45
remove println statements
jdye64 Feb 13, 2023
421 changes: 419 additions & 2 deletions dask_planner/Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dask_planner/Cargo.toml
@@ -15,6 +15,7 @@ datafusion-common = "17.0.0"
datafusion-expr = "17.0.0"
datafusion-optimizer = "17.0.0"
datafusion-sql = "17.0.0"
datafusion-substrait = "17.0.0"
env_logger = "0.10"
log = "^0.4"
mimalloc = { version = "*", default-features = false }
96 changes: 91 additions & 5 deletions dask_planner/src/sql.rs
@@ -9,9 +9,17 @@ pub mod statement;
pub mod table;
pub mod types;

use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, future::Future, sync::Arc};

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::{
arrow::datatypes::{DataType, Field, Schema, TimeUnit},
catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
schema::MemorySchemaProvider,
},
datasource::TableProvider,
prelude::SessionContext,
};
use datafusion_common::{config::ConfigOptions, DFSchema, DataFusionError};
use datafusion_expr::{
logical_plan::Extension,
@@ -34,7 +42,9 @@ use datafusion_sql::{
ResolvedTableReference,
TableReference,
};
use datafusion_substrait::{consumer, serializer};
use pyo3::prelude::*;
use tokio::runtime::Runtime;

use self::logical::{
create_catalog_schema::CreateCatalogSchemaPlanNode,
@@ -63,6 +73,7 @@ use crate::{
show_tables::ShowTablesPlanNode,
PyLogicalPlan,
},
table::DaskTableSource,
},
};

@@ -86,12 +97,13 @@
/// # }
/// ```
#[pyclass(name = "DaskSQLContext", module = "dask_planner", subclass)]
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct DaskSQLContext {
current_catalog: String,
current_schema: String,
schemas: HashMap<String, schema::DaskSchema>,
options: ConfigOptions,
session_ctx: SessionContext,
}

impl ContextProvider for DaskSQLContext {
@@ -108,6 +120,7 @@ impl ContextProvider for DaskSQLContext {
reference.catalog
)));
}

match self.schemas.get(reference.schema) {
Some(schema) => {
let mut resp = None;
@@ -411,6 +424,7 @@ impl DaskSQLContext {
current_schema: default_schema_name.to_owned(),
schemas: HashMap::new(),
options: ConfigOptions::new(),
session_ctx: SessionContext::new(),
}
}

@@ -432,7 +446,27 @@
schema_name: String,
schema: schema::DaskSchema,
) -> PyResult<bool> {
self.schemas.insert(schema_name, schema);
self.schemas.insert(schema_name.clone(), schema);

match self.session_ctx.catalog(&self.current_catalog) {
Some(catalog) => {
let schema_provider = MemorySchemaProvider::new();
let _result = catalog.register_schema(&schema_name, Arc::new(schema_provider));

self.session_ctx
.register_catalog(self.current_catalog.clone(), catalog);
}
None => {
let mem_catalog = MemoryCatalogProvider::new();
let schema_provider = MemorySchemaProvider::new();
let _result = mem_catalog.register_schema(&schema_name, Arc::new(schema_provider));

// Insert the new schema into this newly created catalog
self.session_ctx
.register_catalog(self.current_catalog.clone(), Arc::new(mem_catalog));
}
}

Ok(true)
}

@@ -444,7 +478,30 @@
) -> PyResult<bool> {
match self.schemas.get_mut(&schema_name) {
Some(schema) => {
schema.add_table(table);
schema.add_table(table.clone());

let tbl_ref = TableReference::Partial {
schema: &self.current_schema,
table: table.table_name.as_str(),
};
let tbl_src = self.get_table_provider(tbl_ref).unwrap();
let provider = tbl_src
.as_any()
.downcast_ref::<DaskTableSource>()
.expect("Invalid DefaulTableSource instance");
let tbl_provider = provider.provider.clone() as Arc<dyn TableProvider>;

let catalog = self.session_ctx.catalog(&self.current_catalog).unwrap();
let schema = catalog.schema(&table.schema_name.unwrap()).unwrap();
let _result = schema.register_table(table.table_name.clone(), tbl_provider.clone());

let bare_tbl_ref = TableReference::Bare {
table: table.table_name.as_str(),
};
let _result = self
.session_ctx
.register_table(bare_tbl_ref, tbl_provider.clone());

Ok(true)
}
None => Err(py_runtime_err(format!(
@@ -509,10 +566,39 @@ impl DaskSQLContext {
Err(e) => Err(py_optimization_exp(e)),
}
}

/// Loads a `LogicalPlan` from a local Substrait protobuf file.
pub fn plan_from_substrait(
&self,
plan_path: String,
py: Python,
) -> PyResult<logical::PyLogicalPlan> {
let result = serializer::deserialize(plan_path.as_str());
let plan = Self::wait_for_future(py, result).unwrap();

let result = Self::wait_for_future(
py,
consumer::from_substrait_plan(&mut self.session_ctx.clone(), &plan),
)
.map_err(DataFusionError::from)
.unwrap();

Ok(PyLogicalPlan::from(result))
}
}

/// non-Python methods
impl DaskSQLContext {
/// Utility to collect rust futures with GIL released
pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
F: Send,
F::Output: Send,
{
let rt = Runtime::new().unwrap();
py.allow_threads(|| rt.block_on(f))
}

/// Creates a non-optimized Relational Algebra LogicalPlan from an AST Statement
pub fn _logical_relational_algebra(
&self,
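The new `plan_from_substrait` binding deserializes a Substrait protobuf from disk with `serializer::deserialize`, converts it into a DataFusion `LogicalPlan` through `consumer::from_substrait_plan`, and blocks on both async calls via `wait_for_future` with the GIL released. A minimal sketch of calling the binding directly from Python follows; the import path and constructor arguments are assumptions inferred from the `pyclass` attributes above, and the plan can only reference tables already registered on the context.

```python
# Hedged sketch: exercising the new binding directly rather than through dask_sql.Context.
# The module path and constructor arguments are assumptions based on the pyclass above.
from dask_planner import DaskSQLContext

ctx = DaskSQLContext("dask_catalog", "root")  # default catalog/schema names are illustrative
# ... register schemas and tables here so the plan's table references resolve ...

# Deserialize the Substrait plan and convert it into a DataFusion logical plan
plan = ctx.plan_from_substrait("./tests/integration/proto/df_simple.proto")
print(plan)  # a PyLogicalPlan wrapping the converted plan
```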
11 changes: 9 additions & 2 deletions dask_planner/src/sql/table.rs
@@ -1,7 +1,10 @@
use std::{any::Any, sync::Arc};

use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion::{
arrow::datatypes::{DataType, Field, SchemaRef},
datasource::empty::EmptyTable,
};
use datafusion_common::DFField;
use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource};
use datafusion_optimizer::utils::split_conjunction;
@@ -25,12 +28,16 @@ use crate::{
/// DaskTable wrapper that is compatible with DataFusion logical query plans
pub struct DaskTableSource {
schema: SchemaRef,
pub provider: Arc<EmptyTable>,
}

impl DaskTableSource {
/// Initialize a new `EmptyTable` from a schema.
pub fn new(schema: SchemaRef) -> Self {
Self { schema }
Self {
schema: schema.clone(),
provider: Arc::new(EmptyTable::new(schema)),
}
}
}

23 changes: 16 additions & 7 deletions dask_sql/context.py
@@ -457,6 +457,7 @@ def sql(
return_futures: bool = True,
dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
gpu: bool = False,
substrait: bool = False,
config_options: Dict[str, Any] = None,
) -> Union[dd.DataFrame, pd.DataFrame]:
"""
@@ -483,6 +484,9 @@
to register before executing this query
gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
requires cuDF / dask-cuDF if enabled. Defaults to False.
substrait (:obj:`bool`): If True, the `sql` argument specifies a path to a Substrait plan file, which is loaded
and run as-is without any optimizations. Otherwise, `sql` is treated as a standard SQL string and parsed by
the parsing engine. Defaults to False.
config_options (:obj:`Dict[str,Any]`): Specific configuration options to pass during
query execution
Returns:
@@ -493,14 +497,19 @@
for df_name, df in dataframes.items():
self.create_table(df_name, df, gpu=gpu)

if isinstance(sql, str):
rel, _ = self._get_ral(sql)
elif isinstance(sql, LogicalPlan):
rel = sql
if substrait:
logger.debug(f"Executing query using substrait plan: '{sql}'")
plan = self.context.plan_from_substrait(sql)
print(f"LogicalPlan from substrait: \n{plan}")
else:
raise RuntimeError(
f"Encountered unsupported `LogicalPlan` sql type: {type(sql)}"
)
if isinstance(sql, str):
rel, _ = self._get_ral(sql)
elif isinstance(sql, LogicalPlan):
rel = sql
else:
raise RuntimeError(
f"Encountered unsupported `LogicalPlan` sql type: {type(sql)}"
)

return self._compute_table_from_rel(rel, return_futures)

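Putting the Python-level change together, a minimal end-to-end sketch of the new flag is shown below (see also the test added at the bottom of this diff). The table name and column values are illustrative; the Substrait plan must reference tables that are already registered on the context.

```python
import pandas as pd
from dask_sql import Context

c = Context()
# The plan can only resolve tables that exist on the context, so register them first.
c.create_table("df_simple", pd.DataFrame({"a": [1, 2, 3], "b": [1.1, 2.2, 3.3]}))

# With substrait=True, `sql` is a path to a serialized Substrait plan rather than a SQL string,
# and the plan is executed as-is without going through the SQL parser or optimizer.
result = c.sql("./tests/integration/proto/df_simple.proto", substrait=True)
print(result.compute())
```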
Empty file added df_simple.json
Empty file.
Binary file added tests/integration/proto/df_simple.proto
Binary file not shown.
12 changes: 12 additions & 0 deletions tests/integration/test_substrait.py
@@ -0,0 +1,12 @@
import pandas as pd

from tests.utils import assert_eq


def test_usertable_substrait_join(c):
    return_df = c.sql("./tests/integration/proto/df_simple.proto", substrait=True)
    expected_df = pd.DataFrame(
        {"user_id": [1, 1, 2, 2], "b": [3, 3, 1, 3], "c": [1, 2, 3, 3]}
    )

    assert_eq(return_df, expected_df, check_index=False)