diff --git a/adbc_example.py b/adbc_example.py
new file mode 100644
index 0000000..f3e2ac6
--- /dev/null
+++ b/adbc_example.py
@@ -0,0 +1,32 @@
+import adbc_driver_duckdb.dbapi
+import pyarrow
+from substrait.dataframe import named_table, literal, col, scalar_function
+from substrait.dataframe.functions import add
+
+data = pyarrow.record_batch(
+    [[1, 2, 3, 4], ["a", "b", "c", "d"]],
+    names=["ints", "strs"],
+)
+
+with adbc_driver_duckdb.dbapi.connect(":memory:") as conn:
+    with conn.cursor() as cur:
+        cur.adbc_ingest("AnswerToEverything", data)
+
+        cur.executescript("INSTALL substrait;")
+        cur.executescript("LOAD substrait;")
+
+        table = named_table("AnswerToEverything", conn)
+        table = table.project(
+            literal(1001, type='i64').alias('BigNumber'),
+            col("ints").alias('BigNumber2')
+        )
+
+        table = table.project(
+            scalar_function("functions_arithmetic.yaml", "add",
+                add(col("BigNumber"), col("BigNumber2")),
+                col("BigNumber2")
+            ).alias('BigNumber3')
+        )
+
+        cur.execute(table.plan.SerializeToString())
+        print(cur.fetch_arrow_table())
diff --git a/pyproject.toml b/pyproject.toml
index 4c4ab62..85af6b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ write_to = "src/substrait/_version.py"
 [project.optional-dependencies]
 extensions = ["antlr4-python3-runtime", "pyyaml"]
 gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
-test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
+test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml", "duckdb==1.1.3", "adbc-driver-manager"]
 
 [tool.pytest.ini_options]
 pythonpath = "src"
diff --git a/src/substrait/dataframe/__init__.py b/src/substrait/dataframe/__init__.py
new file mode 100644
index 0000000..f2ffacb
--- /dev/null
+++ b/src/substrait/dataframe/__init__.py
@@ -0,0 +1,92 @@
+"""DataFrame-style helpers for building Substrait plans."""
+
+from typing import Any
+
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.plan_pb2 as stp
+import substrait.gen.proto.type_pb2 as stt
+from substrait.dataframe.expression import (
+    UnboundExpression,
+    UnboundFieldReference,
+    UnboundLiteral,
+    UnboundScalarFunction,
+)
+
+from .dataframe import DataFrame
+
+
+def literal(value: Any, type: str):
+    """Return an unbound literal holding *value* with the type name *type*."""
+    return UnboundLiteral(value, type)
+
+
+def col(column_name: str):
+    """Return an unbound reference to the column named *column_name*."""
+    return UnboundFieldReference(column_name=column_name)
+
+
+def scalar_function(uri: str, function: str, *expressions: UnboundExpression):
+    """Return an unbound call of *function* (declared in *uri*) on *expressions*."""
+    return UnboundScalarFunction(uri, function, *expressions)
+
+
+def pyarrow_to_substrait_type(pa_type):
+    """Map a pyarrow data type to the matching nullable Substrait type.
+
+    Only ``int64``, ``float64`` and ``string`` are handled.
+
+    Raises:
+        ValueError: if *pa_type* is none of the supported types.
+    """
+    import pyarrow
+
+    if pa_type == pyarrow.int64():
+        return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE))
+    elif pa_type == pyarrow.float64():
+        return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE))
+    elif pa_type == pyarrow.string():
+        return stt.Type(
+            string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)
+        )
+
+    # Fail loudly: falling through would return None, which only surfaces
+    # later as an obscure error while the schema struct is being built.
+    raise ValueError(f"Unsupported pyarrow type - {pa_type}")
+
+
+def named_table(name, conn):
+    """Build a DataFrame that reads the table *name*.
+
+    The schema is fetched with ``conn.adbc_get_table_schema(name)``; every
+    column type must be supported by ``pyarrow_to_substrait_type``.
+    """
+    pa_schema = conn.adbc_get_table_schema(name)
+
+    column_names = pa_schema.names
+    struct = stt.Type.Struct(
+        types=[
+            pyarrow_to_substrait_type(pa_schema.field(c).type) for c in column_names
+        ],
+        nullability=stt.Type.Nullability.NULLABILITY_NULLABLE,
+    )
+
+    schema = stt.NamedStruct(
+        names=column_names,
+        struct=struct,
+    )
+
+    rel = stalg.Rel(
+        read=stalg.ReadRel(
+            common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
+            base_schema=schema,
+            named_table=stalg.ReadRel.NamedTable(names=[name]),
+        )
+    )
+
+    plan = stp.Plan(
+        relations=[
+            stp.PlanRel(root=stalg.RelRoot(input=rel, names=column_names))
+        ]
+    )
+
+    return DataFrame(plan=plan)
diff --git a/src/substrait/dataframe/dataframe.py b/src/substrait/dataframe/dataframe.py
new file mode 100644
index 0000000..086ae50
--- /dev/null
+++ b/src/substrait/dataframe/dataframe.py
@@ -0,0 +1,83 @@
+from typing import Optional
+
+from substrait.gen.proto import algebra_pb2 as stalg
+from substrait.gen.proto import plan_pb2 as stp
+from substrait.gen.proto import type_pb2 as stt
+from substrait.gen.proto.extensions import extensions_pb2 as ste
+from substrait.type_inference import infer_rel_schema
+from substrait.dataframe.utils import merge_extensions
+
+
+class DataFrame:
+    """A Substrait plan together with the function extensions it uses."""
+
+    def __init__(self, plan: stp.Plan, extensions: Optional[dict] = None):
+        """Wrap *plan*, declaring *extensions* on it when any are given.
+
+        Args:
+            plan: plan whose last relation holds the current root.
+            extensions: ``{uri: {function_name: function_anchor}}``.
+        """
+        self.plan = plan
+        # Create the dict per instance; a ``{}`` default would be one object
+        # shared by every DataFrame built without extensions.
+        self.extensions = {} if extensions is None else extensions
+
+        if self.extensions:
+            self.plan = stp.Plan(
+                extension_uris=[
+                    ste.SimpleExtensionURI(extension_uri_anchor=i, uri=uri)
+                    for i, uri in enumerate(self.extensions)
+                ],
+                extensions=[
+                    ste.SimpleExtensionDeclaration(
+                        extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
+                            extension_uri_reference=i,
+                            function_anchor=fn_anchor,
+                            name=fn_name,
+                        )
+                    )
+                    for i, functions in enumerate(self.extensions.values())
+                    for fn_name, fn_anchor in functions.items()
+                ],
+                version=self.plan.version,
+                relations=self.plan.relations,
+            )
+
+    def schema(self) -> stt.Type.Struct:
+        """Return the inferred output schema of the plan's last root relation."""
+        return infer_rel_schema(self.plan.relations[-1].root.input)
+
+    def project(self, *expressions):
+        """Return a new DataFrame that projects *expressions* over this one.
+
+        Raises:
+            ValueError: if an expression was not given a name with ``.alias()``.
+        """
+        bound_expressions = [e.bind(self) for e in expressions]
+
+        # Each projected expression supplies one entry of RelRoot.names below.
+        if any(e.alias is None for e in bound_expressions):
+            raise ValueError("Every projected expression needs an alias")
+
+        rel = stalg.Rel(
+            project=stalg.ProjectRel(
+                input=self.plan.relations[-1].root.input,
+                expressions=[e.expression for e in bound_expressions],
+            )
+        )
+
+        names = [e.alias for e in bound_expressions]
+
+        plan = stp.Plan(
+            relations=[
+                stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))
+            ]
+        )
+
+        return DataFrame(
+            plan=plan,
+            extensions=merge_extensions(
+                self.extensions, *[e.extensions for e in bound_expressions]
+            ),
+        )
diff --git a/src/substrait/dataframe/expression.py b/src/substrait/dataframe/expression.py
new file mode 100644
index 0000000..b74c8e6
--- /dev/null
+++ b/src/substrait/dataframe/expression.py
@@ -0,0 +1,184 @@
+from functools import lru_cache
+from typing import Any, Optional
+
+from substrait.dataframe.utils import merge_extensions
+from substrait.function_registry import FunctionRegistry
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.type_pb2 as stt
+from substrait.type_inference import infer_expression_type
+
+# Accepted literal type names mapped to the Expression.Literal field to set.
+_LITERAL_FIELDS = {
+    "boolean": "boolean",
+    "i8": "i8",
+    "int8": "i8",
+    "i16": "i16",
+    "int16": "i16",
+    "i32": "i32",
+    "int32": "i32",
+    "i64": "i64",
+    "int64": "i64",
+    "fp32": "fp32",
+    "fp64": "fp64",
+    "string": "string",
+}
+
+
+@lru_cache(maxsize=None)
+def _default_registry() -> FunctionRegistry:
+    """Return one shared FunctionRegistry instead of building one per bind."""
+    return FunctionRegistry()
+
+
+class BoundExpression:
+    """An expression resolved against the schema of a concrete DataFrame."""
+
+    def __init__(
+        self,
+        expression: stalg.Expression,
+        parent_schema: stt.Type.Struct,
+        alias: Optional[str],
+        extensions: dict,
+    ):
+        self.expression = expression
+        # {uri: {function_name: function_anchor}} used by `expression`.
+        self.extensions = extensions
+        self.parent_schema = parent_schema
+        self.alias = alias
+
+    def dtype(self) -> stt.Type:
+        """Return the expression's type, inferred against the parent schema."""
+        return infer_expression_type(self.expression, self.parent_schema)
+
+
+class UnboundExpression:
+    """Base class for expressions not yet tied to a DataFrame."""
+
+    # Name set through alias(). Kept in its own attribute: assigning to
+    # ``self.alias`` would replace the method with a string on the instance.
+    _alias: Optional[str] = None
+
+    def bind(self, df) -> BoundExpression:
+        """Resolve this expression against *df*; implemented by subclasses."""
+        raise NotImplementedError
+
+    def alias(self, alias: str):
+        """Name the expression and return it, so calls can be chained."""
+        self._alias = alias
+        return self
+
+
+class UnboundLiteral(UnboundExpression):
+    """A constant value tagged with a Substrait type name."""
+
+    def __init__(self, value: Any, type: str):
+        self.value = value
+        self.type = type
+
+    def bind(self, df) -> BoundExpression:
+        """Build the nullable literal expression.
+
+        Raises:
+            ValueError: if the type name is not a supported literal type.
+        """
+        field = _LITERAL_FIELDS.get(self.type)
+        if field is None:
+            raise ValueError(f"Unknown literal type - {self.type}")
+
+        literal = stalg.Expression.Literal(**{field: self.value}, nullable=True)
+
+        return BoundExpression(
+            expression=stalg.Expression(literal=literal),
+            alias=self._alias,
+            parent_schema=df.schema(),
+            extensions={},
+        )
+
+
+class UnboundFieldReference(UnboundExpression):
+    """A reference to a column by name."""
+
+    def __init__(self, column_name: str):
+        self.column_name = column_name
+
+    def bind(self, df) -> BoundExpression:
+        """Turn the column name into a positional reference into *df*'s root names.
+
+        Raises:
+            ValueError: if *df* has no column with that name.
+        """
+        col_names = list(df.plan.relations[-1].root.names)
+        try:
+            field_index = col_names.index(self.column_name)
+        except ValueError:
+            raise ValueError(
+                f"Unknown column {self.column_name!r}, available: {col_names}"
+            ) from None
+
+        return BoundExpression(
+            expression=stalg.Expression(
+                selection=stalg.Expression.FieldReference(
+                    root_reference=stalg.Expression.FieldReference.RootReference(),
+                    direct_reference=stalg.Expression.ReferenceSegment(
+                        struct_field=stalg.Expression.ReferenceSegment.StructField(
+                            field=field_index,
+                        ),
+                    ),
+                ),
+            ),
+            alias=self._alias,
+            parent_schema=df.schema(),
+            extensions={},
+        )
+
+
+class UnboundScalarFunction(UnboundExpression):
+    """A call to an extension scalar function on unbound arguments."""
+
+    def __init__(self, uri: str, function: str, *expressions: UnboundExpression):
+        self.uri = uri
+        self.function = function
+        self.expressions = expressions
+
+    def bind(self, df) -> BoundExpression:
+        """Bind the arguments, then resolve the function by their types.
+
+        Raises:
+            ValueError: if the registry returns no match for the signature.
+        """
+        bound_expressions = [e.bind(df) for e in self.expressions]
+        signature = [e.dtype() for e in bound_expressions]
+
+        found = _default_registry().lookup_function(
+            uri=self.uri,
+            function_name=self.function,
+            signature=signature,
+        )
+        # Guard the unpacking below with a readable error in case the
+        # lookup comes back empty.
+        if found is None:
+            raise ValueError(
+                f"No function {self.function!r} in {self.uri!r} matches the arguments"
+            )
+        func_entry, rtn = found
+
+        return BoundExpression(
+            expression=stalg.Expression(
+                scalar_function=stalg.Expression.ScalarFunction(
+                    function_reference=func_entry.anchor,
+                    output_type=rtn,
+                    arguments=[
+                        stalg.FunctionArgument(value=e.expression)
+                        for e in bound_expressions
+                    ],
+                )
+            ),
+            alias=self._alias,
+            parent_schema=df.schema(),
+            # Include the arguments' extensions too, so functions used by
+            # nested calls are still declared in the final plan.
+            extensions=merge_extensions(
+                {func_entry.uri: {str(func_entry): func_entry.anchor}},
+                *[e.extensions for e in bound_expressions],
+            ),
+        )
diff --git a/src/substrait/dataframe/functions.py b/src/substrait/dataframe/functions.py
new file mode 100644
index 0000000..396b3cc
--- /dev/null
+++ b/src/substrait/dataframe/functions.py
@@ -0,0 +1,7 @@
+from . import scalar_function
+from .expression import UnboundExpression
+
+
+def add(*expressions: UnboundExpression):
+    """Return an unbound ``add`` call from ``functions_arithmetic.yaml``."""
+    return scalar_function("functions_arithmetic.yaml", "add", *expressions)
diff --git a/src/substrait/dataframe/utils.py b/src/substrait/dataframe/utils.py
new file mode 100644
index 0000000..248f751
--- /dev/null
+++ b/src/substrait/dataframe/utils.py
@@ -0,0 +1,12 @@
+def merge_extensions(*extensions):
+    """Merge ``{uri: {function_name: anchor}}`` mappings into a new dict.
+
+    Inputs are left untouched; for a repeated function name the anchor
+    from the later mapping wins.
+    """
+    merged = {}
+    for extension in extensions:
+        for uri, functions in extension.items():
+            merged.setdefault(uri, {}).update(functions)
+
+    return merged
diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py
index 080af9a..f4ddffd 100644
--- a/src/substrait/type_inference.py
+++ b/src/substrait/type_inference.py
@@ -257,7 +257,8 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
             infer_expression_type(e, parent_schema) for e in rel.project.expressions
         ]
         raw_schema = stt.Type.Struct(
-            types=list(parent_schema.types) + expression_types,
+            # HACK: was `list(parent_schema.types) + expression_types`; the parent fields are dropped on purpose to reflect the bug in duckdb substrait.
+            types=list(expression_types),
             nullability=parent_schema.nullability,
         )
 