Commit deb936f

Working through issues between custom catalog and built-in schema

1 parent: bef8f5c

File tree: 4 files changed, +75 -6 lines

  python/datafusion/catalog.py
  python/datafusion/context.py
  python/tests/test_catalog.py
  src/catalog.rs

python/datafusion/catalog.py

Lines changed: 13 additions & 2 deletions
@@ -20,7 +20,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol
 
 import datafusion._internal as df_internal
 
@@ -174,7 +174,9 @@ def schema(self, name: str) -> Schema | None:
         """Retrieve a specific schema from this catalog."""
         ...
 
-    def register_schema(self, name: str, schema: Schema) -> None:  # noqa: B027
+    def register_schema(  # noqa: B027
+        self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema
+    ) -> None:
         """Add a schema to this catalog.
 
         This method is optional. If your catalog provides a fixed list of schemas, you
@@ -229,3 +231,12 @@ def deregister_table(self, name, cascade: bool) -> None:  # noqa: B027
     def table_exist(self, name: str) -> bool:
         """Returns true if the table exists in this schema."""
         ...
+
+
+class SchemaProviderExportable(Protocol):
+    """Type hint for object that has __datafusion_schema_provider__ PyCapsule.
+
+    https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html
+    """
+
+    def __datafusion_schema_provider__(self) -> object: ...
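
Note: the new SchemaProviderExportable protocol is structural, so any object exposing a __datafusion_schema_provider__() method that returns a PyCapsule can be passed to register_schema(), alongside SchemaProvider subclasses and plain Schema objects. A minimal sketch follows; the FfiSchemaProvider wrapper and the capsule source are hypothetical, not part of this commit.

from datafusion import SessionContext


class FfiSchemaProvider:
    """Hypothetical wrapper around a PyCapsule produced by a Rust extension."""

    def __init__(self, capsule: object) -> None:
        # The capsule would wrap an FFI schema provider (e.g. built with datafusion-ffi).
        self._capsule = capsule

    def __datafusion_schema_provider__(self) -> object:
        # Satisfies SchemaProviderExportable by structural typing; no inheritance needed.
        return self._capsule


def register_ffi_schema(ctx: SessionContext, capsule: object) -> None:
    # Register the exported provider under the default "datafusion" catalog.
    ctx.catalog("datafusion").register_schema("ffi_schema", FfiSchemaProvider(capsule))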

python/datafusion/context.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
-from datafusion.catalog import Catalog, Table
+from datafusion.catalog import Catalog, CatalogProvider, Table
 from datafusion.dataframe import DataFrame
 from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
 from datafusion.record_batch import RecordBatchStream
@@ -760,7 +760,7 @@ def catalog_names(self) -> set[str]:
         return self.ctx.catalog_names()
 
     def register_catalog_provider(
-        self, name: str, provider: CatalogProviderExportable | Catalog
+        self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
         """Register a catalog provider."""
         if isinstance(provider, Catalog):
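
With CatalogProvider now accepted directly, a Python subclass can be handed to SessionContext.register_catalog_provider() without wrapping it first. A rough sketch, assuming schema_names() and schema() are the abstract methods (schema() matches the hunk context in catalog.py above); the optional register_schema()/deregister_schema() hooks are omitted.

from __future__ import annotations

from datafusion import SessionContext
from datafusion.catalog import CatalogProvider, Schema


class StaticCatalogProvider(CatalogProvider):
    """Hypothetical provider backed by a fixed dict of schemas."""

    def __init__(self) -> None:
        self._schemas: dict[str, Schema] = {"static_schema": Schema.memory_schema()}

    def schema_names(self) -> set[str]:
        return set(self._schemas)

    def schema(self, name: str) -> Schema | None:
        return self._schemas.get(name)


ctx = SessionContext()
ctx.register_catalog_provider("static_catalog", StaticCatalogProvider())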

python/tests/test_catalog.py

Lines changed: 48 additions & 0 deletions
@@ -162,3 +162,51 @@ def test_python_table_provider(ctx: SessionContext):
     schema.deregister_table("table3")
     schema.register_table("table4", create_dataset())
     assert schema.table_names() == {"table4"}
+
+
+def test_in_end_to_end_python_providers(ctx: SessionContext):
+    """Test registering all python providers and running a query against them."""
+
+    all_catalog_names = [
+        "datafusion",
+        "custom_catalog",
+        "in_mem_catalog",
+    ]
+
+    all_schema_names = [
+        "custom_schema",
+        "in_mem_schema",
+    ]
+
+    ctx.register_catalog_provider(all_catalog_names[1], CustomCatalogProvider())
+    ctx.register_catalog_provider(
+        all_catalog_names[2], dfn.catalog.Catalog.memory_catalog()
+    )
+
+    for catalog_name in all_catalog_names:
+        catalog = ctx.catalog(catalog_name)
+
+        # Clean out previous schemas if they exist so we can start clean
+        for schema_name in catalog.schema_names():
+            catalog.deregister_schema(schema_name, cascade=False)
+
+        catalog.register_schema(all_schema_names[0], CustomSchemaProvider())
+        catalog.register_schema(all_schema_names[1], dfn.catalog.Schema.memory_schema())
+
+        for schema_name in all_schema_names:
+            schema = catalog.schema(schema_name)
+
+            for table_name in schema.table_names():
+                schema.deregister_table(table_name)
+
+            schema.register_table("test_table", create_dataset())
+
+    for catalog_name in all_catalog_names:
+        for schema_name in all_schema_names:
+            table_full_name = f"{catalog_name}.{schema_name}.test_table"
+
+            batches = ctx.sql(f"select * from {table_full_name}").collect()
+
+            assert len(batches) == 1
+            assert batches[0].column(0) == pa.array([1, 2, 3])
+            assert batches[0].column(1) == pa.array([4, 5, 6])
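
The new test leans on fixtures defined earlier in test_catalog.py that this hunk does not show: the pa/dfn imports, create_dataset(), CustomSchemaProvider, and CustomCatalogProvider. The sketch below is a plausible reconstruction for context only, not the code from the repository; the schema provider is assumed to need table_names(), table(), register_table(), deregister_table(), and table_exist().

import pyarrow as pa
import pyarrow.dataset as ds

from datafusion.catalog import SchemaProvider


def create_dataset() -> ds.Dataset:
    """In-memory dataset whose two columns match the values the test asserts."""
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"]
    )
    return ds.dataset([batch])


class CustomSchemaProvider(SchemaProvider):
    """Schema provider backed by a plain dict of table name -> dataset."""

    def __init__(self) -> None:
        self.tables = {}

    def table_names(self) -> set[str]:
        return set(self.tables)

    def table(self, name: str):
        return self.tables.get(name)

    def register_table(self, name: str, table) -> None:
        self.tables[name] = table

    def deregister_table(self, name: str, cascade: bool = False) -> None:
        self.tables.pop(name, None)

    def table_exist(self, name: str) -> bool:
        return name in self.tables

CustomCatalogProvider would follow the same shape as the StaticCatalogProvider sketch after the context.py diff, but seeded with a CustomSchemaProvider under "custom_schema".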

src/catalog.rs

Lines changed: 12 additions & 2 deletions
@@ -314,9 +314,19 @@ impl RustWrappedPySchemaProvider {
 
                 Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
             } else {
-                let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
+                if let Ok(inner_table) = py_table.getattr("table") {
+                    if let Ok(inner_table) = inner_table.extract::<PyTable>() {
+                        return Ok(Some(inner_table.table));
+                    }
+                }
 
-                Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
+                match py_table.extract::<PyTable>() {
+                    Ok(py_table) => Ok(Some(py_table.table)),
+                    Err(_) => {
+                        let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
+                        Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
+                    }
+                }
             }
         })
     }
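
The fallback chain above means a Python schema provider's table() can hand back more than one shape of object: something carrying an inner `table` attribute, an object extractable as PyTable, or anything else, which gets wrapped by the existing Dataset provider. The Python-side sketch below illustrates that last path; the class name is illustrative, and the required SchemaProvider overrides are assumed to be table_names(), table(), and table_exist().

import pyarrow as pa
import pyarrow.dataset as ds

from datafusion import SessionContext
from datafusion.catalog import SchemaProvider


class RawDatasetSchema(SchemaProvider):
    """Schema provider whose table() returns a bare pyarrow dataset."""

    def __init__(self) -> None:
        self._tables = {"numbers": ds.dataset(pa.table({"a": [1, 2, 3]}))}

    def table_names(self) -> set[str]:
        return set(self._tables)

    def table(self, name: str):
        # Not a PyTable, so the Rust side falls through to Dataset::new above.
        return self._tables.get(name)

    def table_exist(self, name: str) -> bool:
        return name in self._tables


ctx = SessionContext()
ctx.catalog("datafusion").register_schema("raw_schema", RawDatasetSchema())
ctx.sql("select * from datafusion.raw_schema.numbers").collect()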
