Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anshuldata committed Sep 25, 2024
1 parent c4d2cb1 commit e98af72
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 4 deletions.
4 changes: 2 additions & 2 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,14 @@ func minArgumentCount(paramTypeList ArgumentList, variadicBehavior *VariadicBeha
if variadicBehavior == nil {
return len(paramTypeList)
}
return variadicBehavior.Min
return len(paramTypeList) + variadicBehavior.Min
}

func maxArgumentCount(paramTypeList ArgumentList, variadicBehavior *VariadicBehavior) int {
if variadicBehavior == nil {
return len(paramTypeList)
}
return variadicBehavior.Max
return len(paramTypeList) + variadicBehavior.Max
}

// NewScalarFuncVariant constructs a variant with the provided name and uri
Expand Down
134 changes: 132 additions & 2 deletions functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,8 +956,8 @@ scalar_functions:
// pass third argument as variadic, it should match against last argument type
argTypes := []types.Type{int64Nullable, int32Nullable, int32Nullable}
require.Len(t, fv, 1)
require.Equal(t, 1, fv[0].MinArgumentCount())
require.Equal(t, 3, fv[0].MaxArgumentCount())
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
match, err := fv[0].Match(argTypes)
require.NoError(t, err)
assert.True(t, match)
Expand Down Expand Up @@ -1146,3 +1146,133 @@ scalar_functions:
// even though function argument allows decimal(P, S)
assert.False(t, match)
}

func TestAggregateFuncMinMax(t *testing.T) {
const uri = "http://localhost/sample.yaml"
const defYaml = `---
aggregate_functions:
-
name: "func_nonvariadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
return: i32
-
name: "func_variadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
variadic:
min: 1
max: 3
return: i32
`

dialectYaml := `
name: test
type: sql
dependencies:
arithmetic:
http://localhost/sample.yaml
supported_types:
i32:
sql_type_name: INTEGER
aggregate_functions:
- name: arithmetic.func_nonvariadic
supported_kernels:
- i32_i32
- name: arithmetic.func_variadic
supported_kernels:
- i32_i32
`
// get substrait function registry
var c extensions.Collection
require.NoError(t, c.Load(uri, strings.NewReader(defYaml)))
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)

// test non-variadic min-max
fv := localRegistry.GetAggregateFunctions(LocalFunctionName("func_nonvariadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 2, fv[0].MinArgumentCount())
require.Equal(t, 2, fv[0].MaxArgumentCount())

// test variadic min-max
fv = localRegistry.GetAggregateFunctions(LocalFunctionName("func_variadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
}

func TestWindowFuncMinMax(t *testing.T) {
const uri = "http://localhost/sample.yaml"
const defYaml = `---
window_functions:
-
name: "func_nonvariadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
return: i32
-
name: "func_variadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
variadic:
min: 1
max: 3
return: i32
`

dialectYaml := `
name: test
type: sql
dependencies:
arithmetic:
http://localhost/sample.yaml
supported_types:
i32:
sql_type_name: INTEGER
window_functions:
- name: arithmetic.func_nonvariadic
supported_kernels:
- i32_i32
- name: arithmetic.func_variadic
supported_kernels:
- i32_i32
`
// get substrait function registry
var c extensions.Collection
require.NoError(t, c.Load(uri, strings.NewReader(defYaml)))
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)

// test non-variadic min-max
fv := localRegistry.GetWindowFunctions(LocalFunctionName("func_nonvariadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 2, fv[0].MinArgumentCount())
require.Equal(t, 2, fv[0].MaxArgumentCount())

// test variadic min-max
fv = localRegistry.GetWindowFunctions(LocalFunctionName("func_variadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
}

0 comments on commit e98af72

Please sign in to comment.