Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion::common::{plan_err, Column};
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::SessionState;
use datafusion::logical_expr::Expr;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -317,7 +318,11 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option<Str
pub struct ParquetMetadataFunc {}

impl TableFunctionImpl for ParquetMetadataFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
fn call(
&self,
_state: &SessionState,
exprs: &[Expr],
) -> Result<Arc<dyn TableProvider>> {
let filename = match exprs.first() {
Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet')
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")
Expand Down
7 changes: 6 additions & 1 deletion datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionProps;
use datafusion::execution::SessionState;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -130,7 +131,11 @@ impl TableProvider for LocalCsvTable {
struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
fn call(
&self,
_state: &SessionState,
exprs: &[Expr],
) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else {
return plan_err!("read_csv requires at least one string argument");
};
Expand Down
13 changes: 10 additions & 3 deletions datafusion/core/src/datasource/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

//! A table that uses a function to generate data

use crate::execution::SessionState;

use super::TableProvider;

use datafusion_common::Result;
Expand All @@ -27,7 +29,8 @@ use std::sync::Arc;
/// A trait for table function implementations
pub trait TableFunctionImpl: Sync + Send {
/// Create a table provider
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
fn call(&self, state: &SessionState, args: &[Expr])
-> Result<Arc<dyn TableProvider>>;
}

/// A table that uses a function to generate data
Expand Down Expand Up @@ -55,7 +58,11 @@ impl TableFunction {
}

/// Get the function implementation and generate a table
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
self.fun.call(args)
pub fn create_table_provider(
&self,
state: &SessionState,
args: &[Expr],
) -> Result<Arc<dyn TableProvider>> {
self.fun.call(state, args)
}
}
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ impl SessionContext {
Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
));
let new_state = SessionStateBuilder::new_from_existing(self.state())
.with_catalog_list(catalog_list)
.with_catalog_list(Some(catalog_list))
.build();
let ctx = SessionContext::new_with_state(new_state);
factory.session_store().with_state(ctx.state_weak_ref());
Expand Down
13 changes: 7 additions & 6 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl SessionState {
SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_catalog_list(catalog_list)
.with_catalog_list(Some(catalog_list))
.with_default_features()
.build()
}
Expand All @@ -296,7 +296,7 @@ impl SessionState {
SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_catalog_list(catalog_list)
.with_catalog_list(Some(catalog_list))
.with_default_features()
.build()
}
Expand Down Expand Up @@ -932,6 +932,7 @@ impl SessionState {
/// be used for all values unless explicitly provided.
///
/// See example on [`SessionState`]
#[derive(Clone)]
pub struct SessionStateBuilder {
session_id: Option<String>,
analyzer: Option<Analyzer>,
Expand Down Expand Up @@ -1140,9 +1141,9 @@ impl SessionStateBuilder {
/// Set the [`CatalogProviderList`]
pub fn with_catalog_list(
mut self,
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: Option<Arc<dyn CatalogProviderList>>,
) -> Self {
self.catalog_list = Some(catalog_list);
self.catalog_list = catalog_list;
self
}

Expand Down Expand Up @@ -1543,7 +1544,7 @@ impl ContextProvider for SessionContextProvider<'_> {
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
let provider = tbl_func.create_table_provider(&args)?;
let provider = tbl_func.create_table_provider(self.state, &args)?;

Ok(provider_as_source(provider))
}
Expand Down Expand Up @@ -1876,7 +1877,7 @@ mod tests {
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;

let session_state = SessionStateBuilder::new()
.with_catalog_list(Arc::new(MemoryCatalogProviderList::new()))
.with_catalog_list(Some(Arc::new(MemoryCatalogProviderList::new())))
.build();
let table_ref = session_state.resolve_table_ref("employee").to_string();
session_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::TaskContext;
use datafusion::execution::{SessionState, TaskContext};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -194,7 +194,11 @@ impl SimpleCsvTable {
struct SimpleCsvTableFunc {}

impl TableFunctionImpl for SimpleCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
fn call(
&self,
_state: &SessionState,
exprs: &[Expr],
) -> Result<Arc<dyn TableProvider>> {
let mut new_exprs = vec![];
let mut filepath = String::new();
for expr in exprs {
Expand Down
6 changes: 5 additions & 1 deletion docs/source/library-user-guide/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,18 @@ In the `call` method, you parse the input `Expr`s and return a `TableProvider`.
```rust
use datafusion::common::plan_err;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::execution::SessionState;
// Other imports here

/// A table function that returns a table provider with the value as a single column
#[derive(Default)]
pub struct EchoFunction {}

impl TableFunctionImpl for EchoFunction {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
fn call(&self,
_state: &SessionState,
exprs: &[Expr],
) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
return plan_err!("First argument must be an integer");
};
Expand Down
Loading