Skip to content

Commit bef8f5c

Browse files
committed
Defining abstract methods for catalog and schema providers
1 parent 25200ab commit bef8f5c

File tree

4 files changed

+107
-22
lines changed

4 files changed

+107
-22
lines changed

python/datafusion/catalog.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from __future__ import annotations
2121

22+
from abc import ABC, abstractmethod
2223
from typing import TYPE_CHECKING
2324

2425
import datafusion._internal as df_internal
@@ -121,6 +122,8 @@ def table(self, name: str) -> Table:
121122

122123
def register_table(self, name, table) -> None:
123124
"""Register a table provider in this schema."""
125+
if isinstance(table, Table):
126+
return self._raw_schema.register_table(name, table.table)
124127
return self._raw_schema.register_table(name, table)
125128

126129
def deregister_table(self, name: str) -> None:
@@ -144,6 +147,11 @@ def __repr__(self) -> str:
144147
"""Return a string representation of the table."""
145148
return self.table.__repr__()
146149

150+
@staticmethod
151+
def from_dataset(dataset: pa.dataset.Dataset) -> Table:
152+
"""Turn a pyarrow Dataset into a Table."""
153+
return Table(df_internal.catalog.RawTable.from_dataset(dataset))
154+
147155
@property
148156
def schema(self) -> pa.Schema:
149157
"""Returns the schema associated with this table."""
@@ -153,3 +161,71 @@ def schema(self) -> pa.Schema:
153161
def kind(self) -> str:
154162
"""Returns the kind of table."""
155163
return self.table.kind
164+
165+
166+
class CatalogProvider(ABC):
167+
@abstractmethod
168+
def schema_names(self) -> set[str]:
169+
"""Set of the names of all schemas in this catalog."""
170+
...
171+
172+
@abstractmethod
173+
def schema(self, name: str) -> Schema | None:
174+
"""Retrieve a specific schema from this catalog."""
175+
...
176+
177+
def register_schema(self, name: str, schema: Schema) -> None: # noqa: B027
178+
"""Add a schema to this catalog.
179+
180+
This method is optional. If your catalog provides a fixed list of schemas, you
181+
do not need to implement this method.
182+
"""
183+
184+
def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027
185+
"""Remove a schema from this catalog.
186+
187+
This method is optional. If your catalog provides a fixed list of schemas, you
188+
do not need to implement this method.
189+
190+
Args:
191+
name: The name of the schema to remove.
192+
cascade: If true, deregister the tables within the schema.
193+
"""
194+
195+
196+
class SchemaProvider(ABC):
197+
def owner_name(self) -> str | None:
198+
"""Returns the owner of the schema.
199+
200+
This is an optional method. The default return is None.
201+
"""
202+
return None
203+
204+
@abstractmethod
205+
def table_names(self) -> set[str]:
206+
"""Set of the names of all tables in this schema."""
207+
...
208+
209+
@abstractmethod
210+
def table(self, name: str) -> Table | None:
211+
"""Retrieve a specific table from this schema."""
212+
...
213+
214+
def register_table(self, name: str, table: Table) -> None: # noqa: B027
215+
"""Add a table to this schema.
216+
217+
This method is optional. If your schema provides a fixed list of tables, you do
218+
not need to implement this method.
219+
"""
220+
221+
def deregister_table(self, name, cascade: bool) -> None: # noqa: B027
222+
"""Remove a table from this schema.
223+
224+
This method is optional. If your schema provides a fixed list of tables, you do
225+
not need to implement this method.
226+
"""
227+
228+
@abstractmethod
229+
def table_exist(self, name: str) -> bool:
230+
"""Returns true if the table exists in this schema."""
231+
...

python/tests/test_catalog.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
from __future__ import annotations
1718

1819
import datafusion as dfn
1920
import pyarrow as pa
@@ -46,20 +47,16 @@ def test_basic(ctx, database):
4647
)
4748

4849

49-
class CustomTableProvider:
50-
def __init__(self):
51-
pass
52-
53-
54-
def create_dataset() -> pa.dataset.Dataset:
50+
def create_dataset() -> Table:
5551
batch = pa.RecordBatch.from_arrays(
5652
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
5753
names=["a", "b"],
5854
)
59-
return ds.dataset([batch])
55+
dataset = ds.dataset([batch])
56+
return Table.from_dataset(dataset)
6057

6158

62-
class CustomSchemaProvider:
59+
class CustomSchemaProvider(dfn.catalog.SchemaProvider):
6360
def __init__(self):
6461
self.tables = {"table1": create_dataset()}
6562

@@ -72,8 +69,14 @@ def register_table(self, name: str, table: Table):
7269
def deregister_table(self, name, cascade: bool = True):
7370
del self.tables[name]
7471

72+
def table(self, name: str) -> Table | None:
73+
return self.tables[name]
74+
75+
def table_exist(self, name: str) -> bool:
76+
return name in self.tables
77+
7578

76-
class CustomCatalogProvider:
79+
class CustomCatalogProvider(dfn.catalog.CatalogProvider):
7780
def __init__(self):
7881
self.schemas = {"my_schema": CustomSchemaProvider()}
7982

src/catalog.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,14 @@ impl PySchema {
207207
let provider: ForeignTableProvider = provider.into();
208208
Arc::new(provider) as Arc<dyn TableProvider>
209209
} else {
210-
let py = table_provider.py();
211-
let provider = Dataset::new(&table_provider, py)?;
212-
Arc::new(provider) as Arc<dyn TableProvider>
210+
match table_provider.extract::<PyTable>() {
211+
Ok(py_table) => py_table.table,
212+
Err(_) => {
213+
let py = table_provider.py();
214+
let provider = Dataset::new(&table_provider, py)?;
215+
Arc::new(provider) as Arc<dyn TableProvider>
216+
}
217+
}
213218
};
214219

215220
let _ = self
@@ -238,6 +243,14 @@ impl PyTable {
238243
self.table.schema().to_pyarrow(py)
239244
}
240245

246+
#[staticmethod]
247+
fn from_dataset(py: Python<'_>, dataset: &Bound<'_, PyAny>) -> PyResult<Self> {
248+
let ds = Arc::new(Dataset::new(dataset, py).map_err(py_datafusion_err)?)
249+
as Arc<dyn TableProvider>;
250+
251+
Ok(Self::new(ds))
252+
}
253+
241254
/// Get the type of this table for metadata/catalog purposes.
242255
#[getter]
243256
fn kind(&self) -> &str {

src/context.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -628,17 +628,10 @@ impl PySessionContext {
628628
let provider: ForeignCatalogProvider = provider.into();
629629
Arc::new(provider) as Arc<dyn CatalogProvider>
630630
} else {
631-
println!("Provider has type {}", provider.get_type());
632631
match provider.extract::<PyCatalog>() {
633-
Ok(py_catalog) => {
634-
println!("registering an existing PyCatalog");
635-
py_catalog.catalog
636-
}
637-
Err(_) => {
638-
println!("registering a rust wrapped catalog provider");
639-
Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
640-
as Arc<dyn CatalogProvider>
641-
}
632+
Ok(py_catalog) => py_catalog.catalog,
633+
Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
634+
as Arc<dyn CatalogProvider>,
642635
}
643636
};
644637

0 commit comments

Comments
 (0)