Skip to content

Commit

Permalink
Merge branch 'main' into anshul/VirtualTableMigrateToNewField
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade authored Nov 12, 2024
2 parents 4823940 + 833624e commit 3be7754
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 1 deletion.
3 changes: 3 additions & 0 deletions expr/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ func TestBoundExpressions(t *testing.T) {
&types.Int32Type{Nullability: types.NullabilityNullable}},
{MustExpr(NewRootFieldRef(NewStructFieldRef(10), &boringSchema.Struct)), false,
&types.StringType{}},
{MustExpr(NewRootFieldRefFromType(
NewStructFieldRef(10), &types.StringType{})), false,
&types.StringType{}},
{MustExpr(NewScalarFunc(extReg, subID, nil,
NewPrimitiveLiteral(int8(1), false),
NewPrimitiveLiteral(int8(5), false))), false,
Expand Down
29 changes: 29 additions & 0 deletions expr/field_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ func NewRootFieldRef(ref Reference, baseSchema *types.StructType) (*FieldReferen
return NewFieldRef(RootReference, ref, baseSchema)
}

// NewRootFieldRefFromType creates a field reference rooted at RootReference
// whose type is the caller-supplied t. It is a convenience wrapper around
// NewFieldRefFromType and returns exactly what that function returns.
//
// Prefer using NewRootFieldRef, which takes the base schema instead of a
// caller-asserted type.
func NewRootFieldRefFromType(ref Reference, t types.Type) (*FieldReference, error) {
	fieldRef, err := NewFieldRefFromType(RootReference, ref, t)
	return fieldRef, err
}

func NewFieldRef(root RootRefType, ref Reference, baseSchema *types.StructType) (*FieldReference, error) {
if ref != nil && root == RootReference && baseSchema == nil {
return nil, fmt.Errorf("%w: must provide the base schema to create a root field ref",
Expand Down Expand Up @@ -573,6 +578,30 @@ func NewFieldRef(root RootRefType, ref Reference, baseSchema *types.StructType)
return nil, substraitgo.ErrNotImplemented
}

// NewFieldRefFromType creates a new field reference with a specific known
// type. Prefer using NewFieldRef, which takes the base schema instead of a
// caller-asserted type.
//
// root must either be RootReference or implement Expression; any other root
// produces an error wrapping substraitgo.ErrInvalidExpr. Only ReferenceSegment
// references are supported: every other reference (a *MaskExpression, or a nil
// ref) produces substraitgo.ErrNotImplemented. t is stored on the returned
// FieldReference as-is; it is not validated here.
func NewFieldRefFromType(root RootRefType, ref Reference, t types.Type) (*FieldReference, error) {
	switch ref.(type) {
	case ReferenceSegment:
		// The type is supplied by the caller, so nothing is resolved from the
		// root; it only has to be one of the recognised kinds.
		if root != RootReference {
			if _, isExpr := root.(Expression); !isExpr {
				return nil, fmt.Errorf("%w: unknown root reference type %v",
					substraitgo.ErrInvalidExpr, root)
			}
		}

		return &FieldReference{
			Reference: ref,
			Root:      root,
			knownType: t,
		}, nil
	case *MaskExpression:
		// Not handled here: falls out of the switch to ErrNotImplemented.
	}

	return nil, substraitgo.ErrNotImplemented
}

func (*FieldReference) isRootRef() {}

func (f *FieldReference) String() string {
Expand Down
37 changes: 37 additions & 0 deletions plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ type Builder interface {
Join(left, right Rel, condition expr.Expression, joinType JoinType) (*JoinRel, error)
NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error)
NamedScan(tableName []string, schema types.NamedStruct) *NamedTableReadRel
NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error)
// NamedInsert inserts data from the input relation into a named table.
NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error)
// NamedDelete deletes rows from a specified named table based on the
// provided input relation, which typically includes conditions that filter
// the rows to delete.
NamedDelete(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error)
VirtualTableRemap(fields []string, remap []int32, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error)
VirtualTable(fields []string, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error)
SortRemap(input Rel, remap []int32, sorts ...expr.SortField) (*SortRel, error)
Expand Down Expand Up @@ -456,6 +463,36 @@ func (b *builder) Join(left, right Rel, condition expr.Expression, joinType Join
return b.JoinAndFilterRemap(left, right, condition, nil, joinType, nil)
}

// NamedWriteRemap builds a NamedTableWriteRel that applies op to the table
// identified by tableName, using input as the source relation and schema as
// the table schema.
//
// It returns errNilInputRel when input is nil, and errOutputMappingOutOfRange
// when any entry of remap is negative or not smaller than the number of types
// in input's remapped record type. The returned relation is created with
// OutputModeNoOutput and carries remap as its output mapping.
func (b *builder) NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error) {
	if input == nil {
		return nil, errNilInputRel
	}

	// Remap indices are validated against the width of the input's output record.
	inputRecord := input.Remap(input.RecordType())
	fieldCount := int32(len(inputRecord.Types))
	for i := range remap {
		if remap[i] < 0 || remap[i] >= fieldCount {
			return nil, errOutputMappingOutOfRange
		}
	}

	writeRel := &NamedTableWriteRel{
		RelCommon:   RelCommon{mapping: remap},
		names:       tableName,
		tableSchema: schema,
		op:          op,
		input:       input,
		outputMode:  OutputModeNoOutput,
	}
	return writeRel, nil
}

// NamedInsert builds a write relation that inserts data from the input
// relation into a named table. It delegates to NamedWriteRemap with
// WriteOpInsert and a nil remap, and returns whatever that call returns.
func (b *builder) NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) {
	var noRemap []int32 // nil: no output remapping is requested
	return b.NamedWriteRemap(input, WriteOpInsert, tableName, schema, noRemap)
}

// NamedDelete builds a write relation that deletes rows from a named table
// based on the input relation. It delegates to NamedWriteRemap with
// WriteOpDelete and a nil remap, and returns whatever that call returns.
func (b *builder) NamedDelete(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) {
	var noRemap []int32 // nil: no output remapping is requested
	return b.NamedWriteRemap(input, WriteOpDelete, tableName, schema, noRemap)
}

func (b *builder) NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error) {
noutput := int32(len(schema.Struct.Types))
for _, idx := range remap {
Expand Down
79 changes: 79 additions & 0 deletions plan/named_write_plan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package plan_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/literal"
"github.com/substrait-io/substrait-go/plan"
"github.com/substrait-io/substrait-go/types"
)

// getFilterForTest1 returns a filter rel for "name LIKE 'Alice'" applied to a
// named scan of the employee_salaries table.
func getFilterForTest1(t *testing.T, b plan.Builder) plan.Rel {
	// Column 0 of the scan output is name.
	const nameColumn = 0

	scan := b.NamedScan([]string{"employee_salaries"}, employeeSalariesSchema)

	// Right-hand side of the condition: the string literal 'Alice'.
	alice, err := literal.NewString("Alice")
	require.NoError(t, err)

	condition := makeConditionExprForLike(t, b, scan, nameColumn, alice)
	return makeFilterRel(t, b, scan, condition)
}

// TestNamedTableInsertRoundTrip builds an insert plan with the builder for
// each case and checks it against the expected JSON stored in
// testdata/<case name>.json.
func TestNamedTableInsertRoundTrip(t *testing.T) {
	type insertCase struct {
		name        string
		tableName   []string
		tableSchema types.NamedStruct
		getInputRel func(t *testing.T, b plan.Builder) plan.Rel
	}

	cases := []insertCase{
		{
			name:        "insert_from_select",
			tableName:   []string{"main", "employee_salaries"},
			tableSchema: employeeSalariesSchema,
			getInputRel: getProjectionForTest1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The expected JSON is the baseline for the comparison.
			expectedJSON, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", tc.name))
			require.NoError(t, err)

			// Build the insert relation on top of the case's input relation.
			builder := plan.NewBuilderDefault()
			inputRel := tc.getInputRel(t, builder)
			insertRel, err := builder.NamedInsert(inputRel, tc.tableName, tc.tableSchema)
			require.NoError(t, err)

			insertPlan, err := builder.Plan(insertRel, nil)
			require.NoError(t, err)

			// The generated plan must match the expected JSON.
			checkRoundTrip(t, string(expectedJSON), insertPlan)
		})
	}
}

// TestNamedTableDeleteRoundTrip builds a delete plan with the builder for
// each case and checks it against the expected JSON stored in
// testdata/<case name>.json.
func TestNamedTableDeleteRoundTrip(t *testing.T) {
	type deleteCase struct {
		name        string
		tableName   []string
		tableSchema types.NamedStruct
		getInputRel func(t *testing.T, b plan.Builder) plan.Rel
	}

	cases := []deleteCase{
		{
			name:        "delete_with_filter",
			tableName:   []string{"main", "employee_salaries"},
			tableSchema: employeeSalariesSchema,
			getInputRel: getFilterForTest1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The expected JSON is the baseline for the comparison.
			expectedJSON, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", tc.name))
			require.NoError(t, err)

			// Build the delete relation on top of the case's input relation.
			builder := plan.NewBuilderDefault()
			inputRel := tc.getInputRel(t, builder)
			deleteRel, err := builder.NamedDelete(inputRel, tc.tableName, tc.tableSchema)
			require.NoError(t, err)

			deletePlan, err := builder.Plan(deleteRel, nil)
			require.NoError(t, err)

			// The generated plan must match the expected JSON.
			checkRoundTrip(t, string(expectedJSON), deletePlan)
		})
	}
}
2 changes: 1 addition & 1 deletion plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ func RelFromProto(rel *proto.Rel, reg expr.ExtensionRegistry) (Rel, error) {
out.fromProtoCommon(rel.Write.Common)
}
switch rel.Write.Op {
case proto.WriteRel_WRITE_OP_CTAS:
case proto.WriteRel_WRITE_OP_CTAS, proto.WriteRel_WRITE_OP_INSERT, proto.WriteRel_WRITE_OP_DELETE:
switch writeType := rel.Write.WriteType.(type) {
case *proto.WriteRel_NamedTable:
out.names = writeType.NamedTable.Names
Expand Down
145 changes: 145 additions & 0 deletions plan/testdata/delete_with_filter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
{
"extensionUris":[
{
"extensionUriAnchor":1,
"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml"
}
],
"extensions":[
{
"extensionFunction":{
"extensionUriReference":1,
"functionAnchor":1,
"name":"contains:str_str"
}
}
],
"relations":[
{
"root":{
"input":{
"write":{
"common":{
"direct":{

}
},
"namedTable":{
"names":[
"main",
"employee_salaries"
]
},
"tableSchema":{
"names":[
"name",
"salary"
],
"struct":{
"types":[
{
"string":{
"nullability":"NULLABILITY_NULLABLE"
}
},
{
"decimal":{
"scale":2,
"precision":10,
"nullability":"NULLABILITY_NULLABLE"
}
}
]
}
},
"op":"WRITE_OP_DELETE",
"input":{
"filter":{
"common":{
"direct":{

}
},
"input":{
"read":{
"baseSchema":{
"names":[
"name",
"salary"
],
"struct":{
"types":[
{
"string":{
"nullability":"NULLABILITY_NULLABLE"
}
},
{
"decimal":{
"scale":2,
"precision":10,
"nullability":"NULLABILITY_NULLABLE"
}
}
]
}
},
"common":{
"direct":{

}
},
"namedTable":{
"names":[
"employee_salaries"
]
}
}
},
"condition":{
"scalarFunction":{
"functionReference":1,
"outputType":{
"bool":{
"nullability":"NULLABILITY_NULLABLE"
}
},
"arguments":[
{
"value":{
"selection":{
"directReference":{
"structField":{
"field":0
}
},
"rootReference":{

}
}
}
},
{
"value":{
"literal":{
"string":"Alice"
}
}
}
]
}
}
}
}
}
}
}
}
],
"version":{
"majorNumber":0,
"minorNumber":29,
"patchNumber":0,
"producer":"substrait-go"
}
}
Loading

0 comments on commit 3be7754

Please sign in to comment.