Skip to content

Commit 9262b49

Browse files
authored
Merge pull request #82 from dolthub/daylon/functions2
Added a framework for creating PostgreSQL functions
2 parents ae1f933 + 5dd1b94 commit 9262b49

15 files changed

+1183
-4
lines changed

CONTRIBUTING.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,24 @@ There are exceptions, as some statements we do not yet support, and cannot suppo
238238
In these cases, we must add a `//TODO:` comment stating what is missing and why it isn't an error.
239239
This will at least allow us to track all such instances where we deviate from the expected behavior, which we can also document elsewhere for users of DoltgreSQL.
240240

241+
### `server/functions`
242+
243+
The `functions` package contains the functions, along with an implementation to approximate the function overloading structure (and type coercion).
244+
245+
The function overloading structure is defined in all files that have the `zinternal_` prefix.
246+
Although not preferable, this was chosen as Go does not allow cyclical references between packages.
247+
Rather than have half of the implementation in `functions`, and the other half in another package, the decision was made to include both in the `functions` package with the added prefix for distinction.
248+
249+
There's an `init` function in `server/functions/zinternal_catalog.go` (this is included in `server/listener.go`) that removes any conflicting GMS function names, and replaces them with the PostgreSQL equivalents.
250+
This means that the functions that we've added behave as expected, and for others to have _some_ sort of implementation rather than outright failing.
251+
We will eventually remove all GMS functions once all PostgreSQL functions have been implemented.
252+
The other internal files all contribute to the generation of functions, along with their proper handling.
253+
254+
Each function (and all overloads) are contained in a single file.
255+
Overloads are named according to their parameters, and prefixed by their target function name.
256+
The set of overloads are then added to the `Catalog` within `server/functions/zinternal_catalog.go`.
257+
To add a new function, it is as simple as creating the `Function`, adding the overloads, and adding it to the `Catalog`.
258+
241259
### `testing/bats`
242260

243261
All Bats tests must follow this general structure:

server/ast/expr.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,8 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
189189
}
190190

191191
switch node.SyntaxMode {
192-
case tree.CastExplicit:
193-
// only acceptable cast type
194-
case tree.CastShort:
195-
return nil, fmt.Errorf("TYPECAST is not yet supported")
192+
case tree.CastExplicit, tree.CastShort:
193+
// Both of these are acceptable
196194
case tree.CastPrepend:
197195
return nil, fmt.Errorf("typed literals are not yet supported")
198196
default:

server/ast/resolvable_type_reference.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv
4646
columnTypeName = columnType.SQLStandardName()
4747
switch columnType.Family() {
4848
case types.DecimalFamily:
49+
columnTypeName = "decimal"
4950
columnTypeLength = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Precision()))))
5051
columnTypeScale = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Scale()))))
5152
case types.JsonFamily:

server/functions/cbrt.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"fmt"
19+
"math"
20+
)
21+
22+
// cbrt represents the PostgreSQL function of the same name.
23+
var cbrt = Function{
24+
Name: "cbrt",
25+
Overloads: []interface{}{cbrt_float},
26+
}
27+
28+
// cbrt_float is one of the overloads of cbrt.
29+
func cbrt_float(num FloatType) (FloatType, error) {
30+
if num.IsNull {
31+
return FloatType{IsNull: true}, nil
32+
}
33+
if num.OriginalType == ParameterType_String {
34+
return FloatType{}, fmt.Errorf("function cbrt(%s) does not exist", ParameterType_String.String())
35+
}
36+
return FloatType{Value: math.Cbrt(num.Value)}, nil
37+
}

server/functions/gcd.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/doltgresql/utils"
21+
)
22+
23+
// gcd represents the PostgreSQL function of the same name.
24+
var gcd = Function{
25+
Name: "gcd",
26+
Overloads: []interface{}{gcd_int_int},
27+
}
28+
29+
// gcd_int_int is one of the overloads of gcd.
30+
func gcd_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) {
31+
if num1.IsNull || num2.IsNull {
32+
return IntegerType{IsNull: true}, nil
33+
}
34+
if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String {
35+
return IntegerType{}, fmt.Errorf("function gcd(%s, %s) does not exist",
36+
num1.OriginalType.String(), num2.OriginalType.String())
37+
}
38+
for num2.Value != 0 {
39+
temp := num2.Value
40+
num2.Value = num1.Value % num2.Value
41+
num1.Value = temp
42+
}
43+
return IntegerType{Value: utils.Abs(num1.Value)}, nil
44+
}

server/functions/lcm.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/doltgresql/utils"
21+
)
22+
23+
// lcm represents the PostgreSQL function of the same name.
24+
var lcm = Function{
25+
Name: "lcm",
26+
Overloads: []interface{}{lcm1_int_int},
27+
}
28+
29+
// lcm1 is one of the overloads of lcm.
30+
func lcm1_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) {
31+
if num1.IsNull || num2.IsNull {
32+
return IntegerType{IsNull: true}, nil
33+
}
34+
if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String {
35+
return IntegerType{}, fmt.Errorf("function lcm(%s, %s) does not exist",
36+
num1.OriginalType.String(), num2.OriginalType.String())
37+
}
38+
gcdResult, err := gcd_int_int(num1, num2)
39+
if err != nil {
40+
return IntegerType{}, err
41+
}
42+
if gcdResult.Value == 0 {
43+
return IntegerType{Value: 0}, nil
44+
}
45+
return IntegerType{Value: utils.Abs((num1.Value * num2.Value) / gcdResult.Value)}, nil
46+
}

server/functions/round.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import "math"
18+
19+
// round represents the PostgreSQL function of the same name.
20+
var round = Function{
21+
Name: "round",
22+
Overloads: []interface{}{round_num, round_float, round_num_dec},
23+
}
24+
25+
// round1 is one of the overloads of round.
26+
func round_num(num NumericType) (NumericType, error) {
27+
if num.IsNull {
28+
return NumericType{IsNull: true}, nil
29+
}
30+
return NumericType{Value: math.Round(num.Value)}, nil
31+
}
32+
33+
// round2 is one of the overloads of round.
34+
func round_float(num FloatType) (FloatType, error) {
35+
if num.IsNull {
36+
return FloatType{IsNull: true}, nil
37+
}
38+
return FloatType{Value: math.RoundToEven(num.Value)}, nil
39+
}
40+
41+
// round3 is one of the overloads of round.
42+
func round_num_dec(num NumericType, decimalPlaces IntegerType) (NumericType, error) {
43+
if num.IsNull || decimalPlaces.IsNull {
44+
return NumericType{IsNull: true}, nil
45+
}
46+
ratio := math.Pow10(int(decimalPlaces.Value))
47+
return NumericType{Value: math.Round(num.Value*ratio) / ratio}, nil
48+
}

server/functions/zinternal_catalog.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2023 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"fmt"
19+
"reflect"
20+
"strings"
21+
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/expression/function"
24+
)
25+
26+
// Function is a name, along with a collection of functions, that represent a single PostgreSQL function with all of its
27+
// overloads.
28+
type Function struct {
29+
Name string
30+
Overloads []any
31+
}
32+
33+
// Catalog contains all of the PostgreSQL functions. If a new function is added, make sure to add it to the catalog here.
34+
var Catalog = []Function{
35+
cbrt,
36+
gcd,
37+
lcm,
38+
round,
39+
}
40+
41+
// init handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not apply to
42+
// PostgreSQL (and functions of the same name often have different behavior).
43+
func init() {
44+
catalogMap := make(map[string]struct{})
45+
for _, f := range Catalog {
46+
catalogMap[strings.ToLower(f.Name)] = struct{}{}
47+
}
48+
var newBuiltIns []sql.Function
49+
for _, f := range function.BuiltIns {
50+
if _, ok := catalogMap[strings.ToLower(f.FunctionName())]; !ok {
51+
newBuiltIns = append(newBuiltIns, f)
52+
}
53+
}
54+
function.BuiltIns = newBuiltIns
55+
56+
allNames := make(map[string]struct{})
57+
for _, catalogItem := range Catalog {
58+
funcName := strings.ToLower(catalogItem.Name)
59+
if _, ok := allNames[funcName]; ok {
60+
panic("duplicate name: " + catalogItem.Name)
61+
}
62+
allNames[funcName] = struct{}{}
63+
64+
baseOverload := &OverloadDeduction{}
65+
for _, functionOverload := range catalogItem.Overloads {
66+
// For each function overload, we first need to ensure that it has an acceptable signature
67+
funcVal := reflect.ValueOf(functionOverload)
68+
if !funcVal.IsValid() || funcVal.IsNil() {
69+
panic(fmt.Errorf("function `%s` has an invalid item", catalogItem.Name))
70+
}
71+
if funcVal.Kind() != reflect.Func {
72+
panic(fmt.Errorf("function `%s` has a non-function item", catalogItem.Name))
73+
}
74+
if funcVal.Type().NumOut() != 2 {
75+
panic(fmt.Errorf("function `%s` has an overload that does not return two values", catalogItem.Name))
76+
}
77+
if funcVal.Type().Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
78+
panic(fmt.Errorf("function `%s` has an overload that does not return an error", catalogItem.Name))
79+
}
80+
returnValType, returnSqlType, ok := ParameterTypeFromReflection(funcVal.Type().Out(0))
81+
if !ok {
82+
panic(fmt.Errorf("function `%s` has an overload that returns as invalid type (`%s`)",
83+
catalogItem.Name, funcVal.Type().Out(0).String()))
84+
}
85+
86+
// Loop through all of the parameters to ensure uniqueness, then store it
87+
currentOverload := baseOverload
88+
for i := 0; i < funcVal.Type().NumIn(); i++ {
89+
paramValType, _, ok := ParameterTypeFromReflection(funcVal.Type().In(i))
90+
if !ok {
91+
panic(fmt.Errorf("function `%s` has an overload with an invalid parameter type (`%s`)",
92+
catalogItem.Name, funcVal.Type().In(i).String()))
93+
}
94+
nextOverload := currentOverload.Parameter[paramValType]
95+
if nextOverload == nil {
96+
nextOverload = &OverloadDeduction{}
97+
currentOverload.Parameter[paramValType] = nextOverload
98+
}
99+
currentOverload = nextOverload
100+
}
101+
if currentOverload.Function.IsValid() && !currentOverload.Function.IsNil() {
102+
panic(fmt.Errorf("function `%s` has duplicate overloads", catalogItem.Name))
103+
}
104+
currentOverload.Function = funcVal
105+
currentOverload.ReturnValType = returnValType
106+
currentOverload.ReturnSqlType = returnSqlType
107+
}
108+
109+
// Store the compiled function into the engine's built-in functions
110+
function.BuiltIns = append(function.BuiltIns, sql.FunctionN{
111+
Name: funcName,
112+
Fn: func(params ...sql.Expression) (sql.Expression, error) {
113+
return &CompiledFunction{
114+
Name: catalogItem.Name,
115+
Parameters: params,
116+
Functions: baseOverload,
117+
}, nil
118+
},
119+
})
120+
}
121+
}

0 commit comments

Comments
 (0)