Skip to content

feat: dataframe api #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions adbc_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import adbc_driver_duckdb.dbapi
import pyarrow
from substrait.dataframe import named_table, literal, col, scalar_function
from substrait.dataframe.functions import add

# Example: build a Substrait plan with the dataframe API and run it on an
# in-memory DuckDB database through ADBC.

# Sample rows that get ingested as the table queried below.
data = pyarrow.record_batch(
    [[1, 2, 3, 4], ["a", "b", "c", "d"]],
    names=["ints", "strs"],
)

with adbc_driver_duckdb.dbapi.connect(":memory:") as conn, conn.cursor() as cur:
    cur.adbc_ingest("AnswerToEverything", data)

    # Install and load DuckDB's substrait extension before executing a plan.
    for statement in ("INSTALL substrait;", "LOAD substrait;"):
        cur.executescript(statement)

    # First projection: a constant column plus a renamed copy of "ints".
    source = named_table("AnswerToEverything", conn)
    renamed = source.project(
        literal(1001, type="i64").alias("BigNumber"),
        col("ints").alias("BigNumber2"),
    )

    # Second projection: (BigNumber + BigNumber2) + BigNumber2, once through
    # the `add` helper and once through the generic `scalar_function`.
    inner_sum = add(col("BigNumber"), col("BigNumber2"))
    outer_sum = scalar_function(
        "functions_arithmetic.yaml",
        "add",
        inner_sum,
        col("BigNumber2"),
    )
    table = renamed.project(outer_sum.alias("BigNumber3"))

    # The serialized plan (bytes) is handed to the cursor for execution.
    cur.execute(table.plan.SerializeToString())
    print(cur.fetch_arrow_table())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ write_to = "src/substrait/_version.py"
[project.optional-dependencies]
extensions = ["antlr4-python3-runtime", "pyyaml"]
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "duckdb==1.1.3", "adbc-driver-manager"]

[tool.pytest.ini_options]
pythonpath = "src"
Expand Down
61 changes: 61 additions & 0 deletions src/substrait/dataframe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from .dataframe import DataFrame
from typing import Any

import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stt
import substrait.gen.proto.plan_pb2 as stp
from substrait.dataframe.expression import UnboundExpression, UnboundFieldReference, UnboundLiteral, UnboundScalarFunction

def literal(value: Any, type: str) -> "UnboundLiteral":
    """Create an unbound literal expression.

    Args:
        value: The constant value the expression should carry.
        type: Name of the literal's type, as understood by ``UnboundLiteral``.

    Returns:
        An ``UnboundLiteral`` wrapping ``value`` and ``type``.
    """
    return UnboundLiteral(value=value, type=type)

def col(column_name: str) -> "UnboundFieldReference":
    """Create an unbound reference to the column called ``column_name``.

    Returns:
        An ``UnboundFieldReference`` for that column name.
    """
    reference = UnboundFieldReference(column_name=column_name)
    return reference

def scalar_function(uri: str, function: str, *expressions: UnboundExpression):
    """Create an unbound scalar-function call.

    Args:
        uri: Identifier of the extension file that declares the function
            (e.g. ``"functions_arithmetic.yaml"``).
        function: Name of the function inside that extension.
        *expressions: Unbound argument expressions, in call order.

    Returns:
        An ``UnboundScalarFunction`` built from the given arguments.
    """
    arguments = expressions
    return UnboundScalarFunction(uri, function, *arguments)

def pyarrow_to_substrait_type(pa_type):
    """Translate a pyarrow data type into a nullable Substrait type.

    Only ``int64``, ``float64`` and ``string`` are handled.

    Args:
        pa_type: The pyarrow data type to translate.

    Returns:
        The matching ``stt.Type`` with nullability set to nullable.

    Raises:
        NotImplementedError: If ``pa_type`` is not one of the supported
            types. Previously the function fell off the end and returned
            ``None``, which only failed later when the value was used.
    """
    # Function-scope import: pyarrow is only imported when this function runs.
    import pyarrow

    nullable = stt.Type.NULLABILITY_NULLABLE

    if pa_type == pyarrow.int64():
        return stt.Type(i64=stt.Type.I64(nullability=nullable))
    if pa_type == pyarrow.float64():
        return stt.Type(fp64=stt.Type.FP64(nullability=nullable))
    if pa_type == pyarrow.string():
        return stt.Type(string=stt.Type.String(nullability=nullable))

    raise NotImplementedError(f"Unsupported pyarrow type: {pa_type}")


def named_table(name, conn):
    """Create a DataFrame that reads the table ``name`` visible through ``conn``.

    The table's Arrow schema is fetched with ``conn.adbc_get_table_schema``
    and converted column by column with ``pyarrow_to_substrait_type``.

    Args:
        name: Name of the table to read.
        conn: Connection object exposing ``adbc_get_table_schema``.

    Returns:
        A ``DataFrame`` whose plan is a single named-table read relation.
    """
    arrow_schema = conn.adbc_get_table_schema(name)
    field_names = arrow_schema.names

    # One Substrait type per column, in schema order.
    field_types = []
    for field_name in field_names:
        arrow_type = arrow_schema.field(field_name).type
        field_types.append(pyarrow_to_substrait_type(arrow_type))

    base_schema = stt.NamedStruct(
        names=field_names,
        struct=stt.Type.Struct(
            types=field_types,
            nullability=stt.Type.Nullability.NULLABILITY_NULLABLE,
        ),
    )

    read_rel = stalg.ReadRel(
        common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
        base_schema=base_schema,
        named_table=stalg.ReadRel.NamedTable(names=[name]),
    )

    # The plan root exposes the same column names as the table itself.
    root = stalg.RelRoot(input=stalg.Rel(read=read_rel), names=field_names)
    return DataFrame(plan=stp.Plan(relations=[stp.PlanRel(root=root)]))
58 changes: 58 additions & 0 deletions src/substrait/dataframe/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from substrait.gen.proto import algebra_pb2 as stalg
from substrait.gen.proto import plan_pb2 as stp
from substrait.gen.proto import type_pb2 as stt
from substrait.gen.proto.extensions import extensions_pb2 as ste
from substrait.type_inference import infer_rel_schema
from substrait.dataframe.utils import merge_extensions

class DataFrame:
    """A Substrait plan plus the function extensions its expressions use.

    ``extensions`` maps an extension URI to a ``{function name: anchor}``
    dict. When it is non-empty, the plan is rebuilt at construction time so
    that it carries the matching extension URI and function declarations.
    """

    def __init__(self, plan: stp.Plan, extensions: "dict | None" = None):
        """Store ``plan`` and attach extension declarations when needed.

        Args:
            plan: The Substrait plan this dataframe represents.
            extensions: Mapping ``{uri: {function_name: function_anchor}}``.
                Defaults to a fresh empty dict per instance (a ``{}`` default
                would be one dict object shared by every instance).
        """
        self.plan = plan
        self.extensions = {} if extensions is None else extensions

        if self.extensions:
            # Each URI's position in the mapping is used as its anchor; the
            # function declarations refer back to it by the same index.
            uri_declarations = [
                ste.SimpleExtensionURI(extension_uri_anchor=uri_anchor, uri=uri)
                for uri_anchor, uri in enumerate(self.extensions)
            ]
            function_declarations = [
                ste.SimpleExtensionDeclaration(
                    extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                        extension_uri_reference=uri_anchor,
                        function_anchor=fn_anchor,
                        name=fn_name,
                    )
                )
                for uri_anchor, functions in enumerate(self.extensions.values())
                for fn_name, fn_anchor in functions.items()
            ]

            self.plan = stp.Plan(
                extension_uris=uri_declarations,
                extensions=function_declarations,
                version=self.plan.version,
                relations=self.plan.relations,
            )

    def schema(self) -> stt.Type.Struct:
        """Return the inferred schema of the plan's last root relation."""
        return infer_rel_schema(self.plan.relations[-1].root.input)

    def project(self, *expressions):
        """Return a new DataFrame projecting ``expressions`` over this one.

        Args:
            *expressions: Unbound expressions; each is bound against this
                dataframe via its ``bind`` method. The bound expression's
                ``alias`` becomes the output column name.

        Returns:
            A new ``DataFrame`` whose plan is a project relation on top of
            this plan's last root input, carrying this dataframe's extensions
            merged with those of the bound expressions.
        """
        bound_expressions = [e.bind(self) for e in expressions]

        rel = stalg.Rel(
            project=stalg.ProjectRel(
                input=self.plan.relations[-1].root.input,
                expressions=[e.expression for e in bound_expressions],
            )
        )

        names = [e.alias for e in bound_expressions]

        plan = stp.Plan(
            relations=[
                stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))
            ]
        )

        merged_extensions = merge_extensions(
            self.extensions,
            *[e.extensions for e in bound_expressions],
        )
        return DataFrame(plan=plan, extensions=merged_extensions)



119 changes: 119 additions & 0 deletions src/substrait/dataframe/expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any
from substrait.gen.proto import algebra_pb2 as stalg
from substrait.function_registry import FunctionRegistry
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stt
from substrait.type_inference import infer_expression_type


class BoundExpression:
    """An expression resolved against a concrete parent schema.

    Attributes are plain instance attributes:
      * ``expression``    - the Substrait expression message.
      * ``parent_schema`` - schema the expression was bound against.
      * ``alias``         - output name given to the expression.
      * ``extensions``    - extension functions the expression relies on.
    """

    def __init__(self, expression: stalg.Expression, parent_schema: stt.Type.Struct, alias: str, extensions: dict):
        self.expression = expression
        # ``extensions`` is exposed as an attribute only. The former
        # ``extensions()`` method was shadowed by this assignment on every
        # instance (calling it raised "'dict' object is not callable"), so it
        # has been removed.
        self.extensions = extensions
        self.parent_schema = parent_schema
        self.alias = alias

    def dtype(self) -> stt.Type:
        """Return the type inferred for ``expression`` under ``parent_schema``."""
        return infer_expression_type(self.expression, self.parent_schema)

class UnboundExpression:
    """Base class for expressions not yet resolved against a dataframe."""

    def bind(self, df):
        """Do nothing and return ``None``; subclasses provide the real binding."""
        return None

    def alias(self, alias: str):
        """Record ``alias`` as this expression's output name and return ``self``.

        NOTE: the value is stored under the attribute name ``alias``, which
        shadows this method on the instance afterwards. Until this method has
        been called, ``self.alias`` is the bound method, not a string.
        """
        setattr(self, "alias", alias)
        return self


class UnboundLiteral(UnboundExpression):
    """A constant value tagged with the name of its Substrait type.

    Accepted type names: ``boolean``, ``i8``/``int8``, ``i16``/``int16``,
    ``i32``/``int32``, ``i64``/``int64``, ``fp32``, ``fp64`` and ``string``.
    """

    # Accepted type name -> name of the ``Expression.Literal`` field to set.
    _LITERAL_FIELDS = {
        "boolean": "boolean",
        "i8": "i8",
        "int8": "i8",
        "i16": "i16",
        "int16": "i16",
        "i32": "i32",
        "int32": "i32",
        "i64": "i64",
        "int64": "i64",
        "fp32": "fp32",
        "fp64": "fp64",
        "string": "string",
    }

    def __init__(self, value: Any, type: str):
        self.value = value
        self.type = type

    def bind(self, df) -> "BoundExpression":
        """Bind the literal against ``df`` and return a ``BoundExpression``.

        Args:
            df: The dataframe whose schema becomes the parent schema.

        Raises:
            ValueError: If ``self.type`` is not an accepted type name.
        """
        field_name = self._LITERAL_FIELDS.get(self.type)
        if field_name is None:
            raise ValueError(f"Unknown literal type - {self.type}")

        # All literals produced here are marked nullable.
        literal = stalg.Expression.Literal(**{field_name: self.value}, nullable=True)

        return BoundExpression(
            expression=stalg.Expression(literal=literal),
            alias=self.alias,
            parent_schema=df.schema(),
            extensions={},
        )


class UnboundFieldReference(UnboundExpression):
    """A reference to a column by name, resolved to a position on ``bind``."""

    def __init__(self, column_name: str):
        self.column_name = column_name

    def bind(self, df) -> "BoundExpression":
        """Resolve the column name against ``df`` and return a ``BoundExpression``.

        The column's position is looked up in the output names of the last
        root relation of ``df``'s plan.

        Args:
            df: The dataframe to resolve the column against.

        Raises:
            ValueError: If ``df`` has no column named ``self.column_name``.
        """
        col_names = list(df.plan.relations[-1].root.names)

        try:
            field_index = col_names.index(self.column_name)
        except ValueError:
            # Same exception type as before, with a message naming the column.
            raise ValueError(
                f"Unknown column {self.column_name!r}; available columns: {col_names}"
            ) from None

        selection = stalg.Expression.FieldReference(
            root_reference=stalg.Expression.FieldReference.RootReference(),
            direct_reference=stalg.Expression.ReferenceSegment(
                struct_field=stalg.Expression.ReferenceSegment.StructField(
                    field=field_index,
                ),
            ),
        )

        return BoundExpression(
            expression=stalg.Expression(selection=selection),
            alias=self.alias,
            parent_schema=df.schema(),
            extensions={},
        )

class UnboundScalarFunction(UnboundExpression):
    """A call to an extension scalar function with unbound arguments."""

    def __init__(self, uri: str, function: str, *expressions: UnboundExpression):
        self.uri = uri
        self.function = function
        self.expressions = expressions

    def bind(self, df):
        """Bind the arguments against ``df`` and resolve the function.

        A new ``FunctionRegistry`` is created on every call; the function is
        looked up by URI, name and the types of the bound arguments.

        Args:
            df: The dataframe the arguments are bound against.

        Returns:
            A ``BoundExpression`` wrapping the scalar-function call. Its
            ``extensions`` record the resolved function under its URI.
        """
        registry = FunctionRegistry()

        bound_args = []
        for argument in self.expressions:
            bound_args.append(argument.bind(df))

        arg_types = [bound.dtype() for bound in bound_args]

        # NOTE(review): presumably lookup_function signals "no match" in a way
        # that makes this unpacking fail - confirm against FunctionRegistry.
        entry, return_type = registry.lookup_function(
            uri=self.uri,
            function_name=self.function,
            signature=arg_types,
        )

        call = stalg.Expression.ScalarFunction(
            function_reference=entry.anchor,
            output_type=return_type,
            arguments=[
                stalg.FunctionArgument(value=bound.expression)
                for bound in bound_args
            ],
        )

        return BoundExpression(
            expression=stalg.Expression(scalar_function=call),
            alias=self.alias,
            parent_schema=df.schema(),
            extensions={entry.uri: {str(entry): entry.anchor}},
        )
5 changes: 5 additions & 0 deletions src/substrait/dataframe/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import scalar_function
from .expression import UnboundExpression

def add(*expressions: UnboundExpression):
    """Build an unbound call to ``add`` from ``functions_arithmetic.yaml``.

    Args:
        *expressions: Unbound argument expressions, in call order.

    Returns:
        Whatever ``scalar_function`` returns for these arguments.
    """
    uri, name = "functions_arithmetic.yaml", "add"
    return scalar_function(uri, name, *expressions)
12 changes: 12 additions & 0 deletions src/substrait/dataframe/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

def merge_extensions(*extensions):
    """Merge several ``{uri: {name: anchor}}`` mappings into a new dict.

    Entries under the same URI are combined; when the same name appears more
    than once under a URI, the value from the later mapping wins. The input
    mappings are not modified.
    """
    merged = {}
    for mapping in extensions:
        for uri, functions in mapping.items():
            # Create the per-URI dict on first sight, then fold entries in.
            merged.setdefault(uri, {}).update(functions)
    return merged
3 changes: 2 additions & 1 deletion src/substrait/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
infer_expression_type(e, parent_schema) for e in rel.project.expressions
]
raw_schema = stt.Type.Struct(
types=list(parent_schema.types) + expression_types,
# types=list(parent_schema.types) + expression_types, # This is on purpose to reflect the bug in duckdb substrait
types=list(expression_types),
nullability=parent_schema.nullability,
)

Expand Down
Loading