Skip to content

Commit

Permalink
Add Match and MatchAt API to FunctionVariant interface
Browse files Browse the repository at this point in the history
* Also fixed default nullability for ScalarFunctionImpl
  • Loading branch information
anshuldata committed Sep 12, 2024
1 parent 2240ec9 commit 94687c0
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 6 deletions.
4 changes: 4 additions & 0 deletions extensions/extension_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"
"sort"

"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
substraitgo "github.com/substrait-io/substrait-go"
"github.com/substrait-io/substrait-go/proto/extensions"
Expand Down Expand Up @@ -179,6 +180,9 @@ func (c *Collection) Load(uri string, r io.Reader) error {
simpleNames := make(map[string]string)

for _, f := range file.ScalarFunctions {
if err := defaults.Set(&f); err != nil {
return fmt.Errorf("failure setting defaults for scalar functions: %w", err)
}
addToMaps[*ScalarFunctionVariant](id, &f, c.scalarMap, simpleNames)
}

Expand Down
2 changes: 2 additions & 0 deletions extensions/extension_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ func TestCollection_GetAllScalarFunctions(t *testing.T) {
sf, ok := c.GetScalarFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
assert.True(t, ok)
assert.Contains(t, scalarFunctions, sf)
// verify that default nullability is set to MIRROR
assert.Equal(t, extensions.MirrorNullability, sf.Nullability())
}
if tt.isAggregate {
af, ok := c.GetAggregateFunc(extensions.ID{URI: tt.uri, Name: tt.signature})
Expand Down
2 changes: 1 addition & 1 deletion extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ type ScalarFunctionImpl struct {
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}
Expand Down
105 changes: 105 additions & 0 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type FunctionVariant interface {
URI() string
ResolveType(argTypes []types.Type) (types.Type, error)
Variadic() *VariadicBehavior
Match(argumentTypes []types.Type) (bool, error)
MatchAt(typ types.Type, pos int) (bool, error)
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
Expand Down Expand Up @@ -84,6 +86,78 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx
return outType, nil
}

// TODO: Handle Variadic function
func matchArguments(nullability NullabilityHandling, paramTypeList ArgumentList, actualTypes []types.Type) (bool, error) {
if len(paramTypeList) != len(actualTypes) {
return false, nil
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil
}
for argPos := range paramTypeList {
match, err1 := matchArgumentAtCommon(actualTypes[argPos], argPos, nullability, funcDefArgList)
if err1 != nil {
return false, err1
}
if !match {
return false, nil
}
}
return true, nil
}

// TODO: Handle Variadic function
func matchArgumentAt(actualType types.Type, argPos int, nullability NullabilityHandling, paramTypeList ArgumentList) (bool, error) {
if argPos < 0 {
return false, fmt.Errorf("non-zero argument position")
}
if argPos >= len(paramTypeList) {
return false, fmt.Errorf("%w: argument position %d out of range", substraitgo.ErrNotFound, argPos)
}
funcDefArgList, err := getFuncDefFromArgList(paramTypeList)
if err != nil {
return false, nil
}
return matchArgumentAtCommon(actualType, argPos, nullability, funcDefArgList)
}

func matchArgumentAtCommon(actualType types.Type, argPos int, nullability NullabilityHandling, funcDefArgList []types.FuncDefArgType) (bool, error) {
if HasSyncParams(funcDefArgList) {
return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented)
}
funcDefArg := funcDefArgList[argPos]
switch nullability {
case DiscreteNullability:
return funcDefArg.MatchWithNullability(actualType), nil
case MirrorNullability, DeclaredOutputNullability:
return funcDefArg.MatchWithoutNullability(actualType), nil
}
// unreachable case
return false, fmt.Errorf("invalid nullability type: %s", nullability)
}

func getFuncDefFromArgList(paramTypeList ArgumentList) ([]types.FuncDefArgType, error) {
var out []types.FuncDefArgType
for argPos, param := range paramTypeList {
switch paramType := param.(type) {
case ValueArg:
funcDefArgType, err := paramType.Value.Expr.(*parser.Type).ArgType()
if err != nil {
return nil, err
}
out = append(out, funcDefArgType)
case EnumArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
case TypeArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
default:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
}
}
return out, nil
}

func parseFuncName(compoundName string) (name string, args ArgumentList) {
name, argsStr, _ := strings.Cut(compoundName, ":")
if len(argsStr) == 0 {
Expand All @@ -102,6 +176,25 @@ func parseFuncName(compoundName string) (name string, args ArgumentList) {
return name, args
}

// Match this function matches input arguments against definition of this functions argument list
// returns (true, nil) if all input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input argument nullability is not correctly
// set this function will return error
// returns (false, nil) valid input argument type and no match this function returns
func (s *ScalarFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)
}

// MatchAt this function matches input argument at position against definition of this
// functions argument at same position
// returns (true, nil) if all input argument can type replace the function definition argument
// returns (false, err) for invalid input argument. For e.g. if input argument nullability is not correctly
// set this function will return error
// returns (false, nil) valid input argument type and no match this function returns
func (s *ScalarFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)
}

// NewScalarFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -268,6 +361,12 @@ func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error)
}
func (s *AggregateFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *AggregateFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *AggregateFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)
}
func (s *AggregateFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)
}

type WindowFunctionVariant struct {
name string
Expand Down Expand Up @@ -376,6 +475,12 @@ func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
func (s *WindowFunctionVariant) Ordered() bool { return s.impl.Ordered }
func (s *WindowFunctionVariant) MaxSet() int { return s.impl.MaxSet }
func (s *WindowFunctionVariant) WindowType() WindowType { return s.impl.WindowType }
func (s *WindowFunctionVariant) Match(argumentTypes []types.Type) (bool, error) {
return matchArguments(s.Nullability(), s.impl.Args, argumentTypes)
}
func (s *WindowFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args)
}

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
Expand Down
Loading

0 comments on commit 94687c0

Please sign in to comment.