diff --git a/go.mod b/go.mod index fcf48ff0be..6245fe73d7 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/madflojo/testcerts v1.1.1 github.com/mitchellh/go-ps v1.0.0 github.com/mitchellh/go-wordwrap v1.0.1 + github.com/pganalyze/pg_query_go/v6 v6.0.0 github.com/pierrre/geohash v1.0.0 github.com/pkg/profile v1.5.0 github.com/sergi/go-diff v1.1.0 diff --git a/go.sum b/go.sum index 7698945b60..d52fed4e65 100644 --- a/go.sum +++ b/go.sum @@ -729,6 +729,8 @@ github.com/pborman/getopt v0.0.0-20180729010549-6fdd0a2c7117/go.mod h1:85jBQOZwp github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= +github.com/pganalyze/pg_query_go/v6 v6.0.0 h1:in6RkR/apfqlAtvqgDxd4Y4o87a5Pr8fkKDB4DrDo2c= +github.com/pganalyze/pg_query_go/v6 v6.0.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.6 h1:ueMTcBBFrbT8K4uGDNNZPa8Z7LtPV7Cl0TDjaeHxP44= diff --git a/server/analyzer/resolve_type.go b/server/analyzer/resolve_type.go index 7e6214a7c9..dd80e8004c 100644 --- a/server/analyzer/resolve_type.go +++ b/server/analyzer/resolve_type.go @@ -20,6 +20,8 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" + pgnodes "github.com/dolthub/doltgresql/server/node" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/expression" @@ -47,6 +49,32 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, return transform.Node(node, func(node sql.Node) (sql.Node, 
transform.TreeIdentity, error) { var same = transform.SameTree switch n := node.(type) { + case *plan.AddColumn: + col := n.Column() + if rt, ok := col.Type.(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { + dt, err := resolveType(ctx, rt) + if err != nil { + return nil, transform.NewTree, err + } + same = transform.NewTree + col.Type = dt + } + return node, same, nil + case *pgnodes.CreateFunction: + retType, err := resolveType(ctx, n.ReturnType) + if err != nil { + return nil, transform.NewTree, err + } + paramTypes := make([]*pgtypes.DoltgresType, len(n.ParameterTypes)) + for i := range n.ParameterTypes { + paramTypes[i], err = resolveType(ctx, n.ParameterTypes[i]) + if err != nil { + return nil, transform.NewTree, err + } + } + n.ReturnType = retType + n.ParameterTypes = paramTypes + return node, transform.NewTree, nil case *plan.CreateTable: for _, col := range n.TargetSchema() { if rt, ok := col.Type.(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { @@ -59,17 +87,6 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, } } return node, same, nil - case *plan.AddColumn: - col := n.Column() - if rt, ok := col.Type.(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { - dt, err := resolveType(ctx, rt) - if err != nil { - return nil, transform.NewTree, err - } - same = transform.NewTree - col.Type = dt - } - return node, same, nil case *plan.ModifyColumn: col := n.NewColumn() if rt, ok := col.Type.(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() { @@ -116,6 +133,9 @@ func ResolveTypeForExprs(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, // resolveType resolves any type that is unresolved yet. (e.g.: domain types, built-in types that schema specified, etc.) 
func resolveType(ctx *sql.Context, typ *pgtypes.DoltgresType) (*pgtypes.DoltgresType, error) { + if typ.IsResolvedType() { + return typ, nil + } schema, err := core.GetSchemaName(ctx, nil, typ.Schema()) if err != nil { return nil, err @@ -126,6 +146,13 @@ func resolveType(ctx *sql.Context, typ *pgtypes.DoltgresType) (*pgtypes.Doltgres } resolvedTyp, exists := typs.GetType(id.NewType(schema, typ.Name())) if !exists { + // If a blank schema is provided, then we'll also try the pg_catalog, since a type is most likely to be there + if typ.Schema() == "" { + resolvedTyp, exists = typs.GetType(id.NewType("pg_catalog", typ.Name())) + if exists { + return resolvedTyp, nil + } + } return nil, pgtypes.ErrTypeDoesNotExist.New(typ.Name()) } return resolvedTyp, nil diff --git a/server/ast/alter_function.go b/server/ast/alter_function.go index b1d7759d87..e006cadf42 100644 --- a/server/ast/alter_function.go +++ b/server/ast/alter_function.go @@ -22,7 +22,7 @@ import ( // nodeAlterFunction handles *tree.AlterFunction nodes. func nodeAlterFunction(ctx *Context, node *tree.AlterFunction) (vitess.Statement, error) { - err := verifyRedundantRoutineOption(ctx, node.Options) + _, err := validateRoutineOptions(ctx, node.Options) if err != nil { return nil, err } diff --git a/server/ast/alter_procedure.go b/server/ast/alter_procedure.go index 0f372d4733..edabbced1d 100644 --- a/server/ast/alter_procedure.go +++ b/server/ast/alter_procedure.go @@ -22,7 +22,7 @@ import ( // nodeAlterProcedure handles *tree.AlterProcedure nodes. 
func nodeAlterProcedure(ctx *Context, node *tree.AlterProcedure) (vitess.Statement, error) { - err := verifyRedundantRoutineOption(ctx, node.Options) + _, err := validateRoutineOptions(ctx, node.Options) if err != nil { return nil, err } diff --git a/server/ast/context.go b/server/ast/context.go index 5c6c010cbd..35c5e22524 100644 --- a/server/ast/context.go +++ b/server/ast/context.go @@ -14,19 +14,24 @@ package ast -import "github.com/dolthub/doltgresql/server/auth" +import ( + "github.com/dolthub/doltgresql/postgres/parser/parser" + "github.com/dolthub/doltgresql/server/auth" +) // Context contains any relevant context for the AST conversion. For example, the auth system uses the context to // determine which larger statement an expression exists in, which may influence how the expression should handle // authorization. type Context struct { - authContext *auth.AuthContext + authContext *auth.AuthContext + originalQuery string } // NewContext returns a new *Context. -func NewContext() *Context { +func NewContext(postgresStmt parser.Statement) *Context { return &Context{ - authContext: auth.NewAuthContext(), + authContext: auth.NewAuthContext(), + originalQuery: postgresStmt.SQL, } } diff --git a/server/ast/convert.go b/server/ast/convert.go index 7aadccf19d..d84fa5b3fc 100644 --- a/server/ast/convert.go +++ b/server/ast/convert.go @@ -25,7 +25,7 @@ import ( // Convert converts a Postgres AST into a Vitess AST. 
func Convert(postgresStmt parser.Statement) (vitess.Statement, error) { - ctx := NewContext() + ctx := NewContext(postgresStmt) switch stmt := postgresStmt.AST.(type) { case *tree.AlterAggregate: return nodeAlterAggregate(ctx, stmt) diff --git a/server/ast/create_function.go b/server/ast/create_function.go index ff22fc90c7..d824f2aff7 100644 --- a/server/ast/create_function.go +++ b/server/ast/create_function.go @@ -15,33 +15,65 @@ package ast import ( - "github.com/cockroachdb/errors" + "strings" + "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + pgnodes "github.com/dolthub/doltgresql/server/node" + "github.com/dolthub/doltgresql/server/plpgsql" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeCreateFunction handles *tree.CreateFunction nodes. func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Statement, error) { - err := verifyRedundantRoutineOption(ctx, node.Options) + options, err := validateRoutineOptions(ctx, node.Options) if err != nil { return nil, err } - - return NotYetSupportedError("CREATE FUNCTION statement is not yet supported") -} - -// verifyRedundantRoutineOption checks for each option defined only once. -// If there is multiple definition of the same option, it returns an error. 
-func verifyRedundantRoutineOption(ctx *Context, options []tree.RoutineOption) error { - var optDefined = make(map[tree.FunctionOption]struct{}) - for _, opt := range options { - if _, ok := optDefined[opt.OptionType]; ok { - return errors.Errorf("ERROR: conflicting or redundant options") - } else { - optDefined[opt.OptionType] = struct{}{} + // We only support PL/pgSQL for now, so we'll verify that first + if languageOption, ok := options[tree.OptionLanguage]; ok { + if strings.ToLower(languageOption.Language) != "plpgsql" { + return nil, errors.Errorf("CREATE FUNCTION only supports PL/pgSQL for now") } + } else { + return nil, errors.Errorf("CREATE FUNCTION does not define an input language") + } + // PL/pgSQL is different from standard Postgres SQL, so we have to use a special parser to handle it. + // This parser also requires the full `CREATE FUNCTION` string, so we'll pass that. + parsedBody, err := plpgsql.Parse(ctx.originalQuery) + if err != nil { + return nil, err + } + // Grab the rest of the information that we'll need to create the function + tableName := node.Name.ToTableName() + schemaName := tableName.Schema() + if len(schemaName) == 0 { + // TODO: fix function finder such that it doesn't always assume pg_catalog + schemaName = "pg_catalog" + } + retType := pgtypes.Void + if len(node.RetType) == 1 { + retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(node.RetType[0].Type.SQLString())) + } + paramNames := make([]string, len(node.Args)) + paramTypes := make([]*pgtypes.DoltgresType, len(node.Args)) + for i, arg := range node.Args { + paramNames[i] = arg.Name.String() + paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(arg.Type.SQLString())) } - return nil + // Returns the stored procedure call with all options + return vitess.InjectedStatement{ + Statement: pgnodes.NewCreateFunction( + tableName.Table(), + schemaName, + retType, + paramNames, + paramTypes, + true, // TODO: implement strict check + parsedBody, + ), + 
Children: nil, + }, nil } diff --git a/server/ast/create_procedure.go b/server/ast/create_procedure.go index 2c0b25fef8..b8ce9a6b9b 100644 --- a/server/ast/create_procedure.go +++ b/server/ast/create_procedure.go @@ -15,6 +15,8 @@ package ast import ( + "github.com/cockroachdb/errors" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" @@ -22,10 +24,24 @@ import ( // nodeCreateProcedure handles *tree.CreateProcedure nodes. func nodeCreateProcedure(ctx *Context, node *tree.CreateProcedure) (vitess.Statement, error) { - err := verifyRedundantRoutineOption(ctx, node.Options) + _, err := validateRoutineOptions(ctx, node.Options) if err != nil { return nil, err } return NotYetSupportedError("CREATE PROCEDURE statement is not yet supported") } + +// validateRoutineOptions ensures that each option is defined only once. Returns a map containing all options, or an +// error if an option is invalid or is defined multiple times. +func validateRoutineOptions(ctx *Context, options []tree.RoutineOption) (map[tree.FunctionOption]tree.RoutineOption, error) { + var optDefined = make(map[tree.FunctionOption]tree.RoutineOption) + for _, opt := range options { + if _, ok := optDefined[opt.OptionType]; ok { + return nil, errors.Errorf("ERROR: conflicting or redundant options") + } else { + optDefined[opt.OptionType] = opt + } + } + return optDefined, nil +} diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index a031811a73..6138a1ce45 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -34,7 +34,12 @@ var initializedFunctions = false // from within an init(). 
func RegisterFunction(f FunctionInterface) { if initializedFunctions { - panic("attempted to register a function after the init() phase") + // TODO: this should be able to handle overloads + name := strings.ToLower(f.GetName()) + if err := validateFunction(name, []FunctionInterface{f}); err != nil { + panic(err) // TODO: replace panics here with errors + } + compileNonOperatorFunction(name, []FunctionInterface{f}) } switch f := f.(type) { case Function0: @@ -76,7 +81,7 @@ func Initialize() { compileFunctions() } -// replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions +// replaceGmsBuiltIns replaces all GMS built-ins that have conflicting names with PostgreSQL functions. func replaceGmsBuiltIns() { functionNames := make(map[string]struct{}) for name := range Catalog { @@ -91,59 +96,71 @@ func replaceGmsBuiltIns() { function.BuiltIns = newBuiltIns } -// validateFunctions panics if any functions are defined incorrectly or ambiguously +// validateFunctions panics if any functions are defined incorrectly or ambiguously. 
func validateFunctions() { for funcName, overloads := range Catalog { - funcName := funcName - // Verify that each function uses the correct Function overload - for _, functionOverload := range overloads { - if functionOverload.GetExpectedParameterCount() >= 0 && - len(functionOverload.GetParameters()) != functionOverload.GetExpectedParameterCount() { - panic(errors.Errorf("function `%s` should have %d arguments but has %d arguments", - funcName, functionOverload.GetExpectedParameterCount(), len(functionOverload.GetParameters()))) - } + if err := validateFunction(funcName, overloads); err != nil { + panic(err) } - // Verify that all overloads are unique - for functionIndex, f1 := range overloads { - for _, f2 := range overloads[functionIndex+1:] { - sameCount := 0 - if f1.GetExpectedParameterCount() == f2.GetExpectedParameterCount() { - f2Parameters := f2.GetParameters() - for parameterIndex, f1Parameter := range f1.GetParameters() { - if f1Parameter.Equals(f2Parameters[parameterIndex]) { - sameCount++ - } + } +} + +// validateFunction validates whether functions are defined incorrectly or ambiguously. 
+func validateFunction(funcName string, overloads []FunctionInterface) error { + // Verify that each function uses the correct Function overload + for _, functionOverload := range overloads { + if functionOverload.GetExpectedParameterCount() >= 0 && + len(functionOverload.GetParameters()) != functionOverload.GetExpectedParameterCount() { + return errors.Errorf("function `%s` should have %d arguments but has %d arguments", + funcName, functionOverload.GetExpectedParameterCount(), len(functionOverload.GetParameters())) + } + } + // Verify that all overloads are unique + for functionIndex, f1 := range overloads { + for _, f2 := range overloads[functionIndex+1:] { + sameCount := 0 + if f1.GetExpectedParameterCount() == f2.GetExpectedParameterCount() { + f2Parameters := f2.GetParameters() + for parameterIndex, f1Parameter := range f1.GetParameters() { + if f1Parameter.Equals(f2Parameters[parameterIndex]) { + sameCount++ } } - if sameCount == f1.GetExpectedParameterCount() && f1.GetExpectedParameterCount() > 0 { - panic(errors.Errorf("duplicate function overloads on `%s`", funcName)) - } + } + if sameCount == f1.GetExpectedParameterCount() && f1.GetExpectedParameterCount() > 0 { + return errors.Errorf("duplicate function overloads on `%s`", funcName) } } } + return nil } -// compileFunctions creates a CompiledFunction for each overload of each function in the catalog -func compileFunctions() { - for funcName, overloads := range Catalog { - overloadTree := NewOverloads() - for _, functionOverload := range overloads { - if err := overloadTree.Add(functionOverload); err != nil { - panic(err) - } +// compileNonOperatorFunction creates a CompiledFunction for each overload of the given function. 
+func compileNonOperatorFunction(funcName string, overloads []FunctionInterface) { + overloadTree := NewOverloads() + for _, functionOverload := range overloads { + if err := overloadTree.Add(functionOverload); err != nil { + panic(err) } + } - // Store the compiled function into the engine's built-in functions - // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog - createFunc := func(params ...sql.Expression) (sql.Expression, error) { - return NewCompiledFunction(funcName, params, overloadTree, false), nil - } - function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ - Name: funcName, - Fn: createFunc, - }) - compiledCatalog[funcName] = createFunc - namedCatalog[funcName] = overloads + // Store the compiled function into the engine's built-in functions + // TODO: don't do this, use an actual contract for communicating these functions to the engine catalog + createFunc := func(params ...sql.Expression) (sql.Expression, error) { + return NewCompiledFunction(funcName, params, overloadTree, false), nil + } + function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ + Name: funcName, + Fn: createFunc, + }) + compiledCatalog[funcName] = createFunc + namedCatalog[funcName] = overloads +} + +// compileFunctions creates a CompiledFunction for each overload of each function in the catalog. +func compileFunctions() { + for funcName, overloads := range Catalog { + compileNonOperatorFunction(funcName, overloads) } // Build the overload for all unary and binary functions based on their operator. 
This will be used for fallback if diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index ee684fbe56..55b17c3b82 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -304,7 +305,7 @@ func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, err case Function4: return f.Callable(ctx, ([5]*pgtypes.DoltgresType)(c.callResolved), args[0], args[1], args[2], args[3]) case InterpretedFunction: - return f.Call(ctx, c.runner, c.callResolved, args) + return plpgsql.Call(ctx, f, c.runner, c.callResolved, args) default: return nil, cerrors.Errorf("unknown function type in CompiledFunction::Eval") } diff --git a/server/functions/framework/interpreter.go b/server/functions/framework/interpreted_function.go similarity index 81% rename from server/functions/framework/interpreter.go rename to server/functions/framework/interpreted_function.go index 6a681a17e7..e606773f75 100644 --- a/server/functions/framework/interpreter.go +++ b/server/functions/framework/interpreted_function.go @@ -15,15 +15,16 @@ package framework import ( - "errors" "fmt" "strconv" "strings" + "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/lib/pq" "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -37,11 +38,11 @@ type InterpretedFunction struct { Variadic bool IsNonDeterministic bool Strict bool - Labels map[string]int - Statements []InterpreterOperation + Statements []plpgsql.InterpreterOperation } var _ FunctionInterface = InterpretedFunction{} +var _ plpgsql.InterpretedFunction = InterpretedFunction{} // 
GetExpectedParameterCount implements the interface FunctionInterface. func (iFunc InterpretedFunction) GetExpectedParameterCount() int { @@ -58,11 +59,21 @@ func (iFunc InterpretedFunction) GetParameters() []*pgtypes.DoltgresType { return iFunc.ParameterTypes } +// GetParameterNames returns the names of all parameters. +func (iFunc InterpretedFunction) GetParameterNames() []string { + return iFunc.ParameterNames +} + // GetReturn implements the interface FunctionInterface. func (iFunc InterpretedFunction) GetReturn() *pgtypes.DoltgresType { return iFunc.ReturnType } +// GetStatements returns the contained statements. +func (iFunc InterpretedFunction) GetStatements() []plpgsql.InterpreterOperation { + return iFunc.Statements +} + // InternalID implements the interface FunctionInterface. func (iFunc InterpretedFunction) InternalID() id.Id { return iFunc.ID.AsId() @@ -89,8 +100,8 @@ func (iFunc InterpretedFunction) VariadicIndex() int { return -1 } -// querySingleReturn handles queries that are supposed to return a single value. -func (iFunc InterpretedFunction) querySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) { +// QuerySingleReturn handles queries that are supposed to return a single value. 
+func (InterpretedFunction) QuerySingleReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) { if len(bindings) > 0 { for i, bindingName := range bindings { variable := stack.GetVariable(bindingName) @@ -108,7 +119,7 @@ func (iFunc InterpretedFunction) querySingleReturn(ctx *sql.Context, stack Inter stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1) } } - sch, rowIter, _, err := stack.runner.QueryWithBindings(ctx, stmt, nil, nil, nil) + sch, rowIter, _, err := stack.Runner().QueryWithBindings(ctx, stmt, nil, nil, nil) if err != nil { return nil, err } @@ -143,8 +154,8 @@ func (iFunc InterpretedFunction) querySingleReturn(ctx *sql.Context, stack Inter return castFunc(ctx, rows[0][0], targetType) } -// queryMultiReturn handles queries that may return multiple values over multiple rows. -func (iFunc InterpretedFunction) queryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) { +// QueryMultiReturn handles queries that may return multiple values over multiple rows. 
+func (InterpretedFunction) QueryMultiReturn(ctx *sql.Context, stack plpgsql.InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) { if len(bindings) > 0 { for i, bindingName := range bindings { variable := stack.GetVariable(bindingName) @@ -162,7 +173,7 @@ func (iFunc InterpretedFunction) queryMultiReturn(ctx *sql.Context, stack Interp stmt = strings.Replace(stmt, "$"+strconv.Itoa(i+1), formattedVar, 1) } } - _, rowIter, _, err = stack.runner.QueryWithBindings(ctx, stmt, nil, nil, nil) + _, rowIter, _, err = stack.Runner().QueryWithBindings(ctx, stmt, nil, nil, nil) return rowIter, err } diff --git a/server/functions/init.go b/server/functions/init.go index 8489b42a75..a80470be4a 100644 --- a/server/functions/init.go +++ b/server/functions/init.go @@ -104,7 +104,6 @@ func Init() { initGcd() initGenRandomUuid() initInitcap() - initInterpretedExamples() initLcm() initLeft() initLength() diff --git a/server/functions/interpreted_examples.go b/server/functions/interpreted_examples.go deleted file mode 100644 index 549add03b8..0000000000 --- a/server/functions/interpreted_examples.go +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2025 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package functions - -import ( - "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" - pgtypes "github.com/dolthub/doltgresql/server/types" -) - -// initInterpretedExamples registers example functions to the catalog. These are temporary, and exist solely to test the -// interpreter functionality. -func initInterpretedExamples() { - framework.RegisterFunction(interpretedAssignment) - framework.RegisterFunction(interpretedAlias) -} - -// interpretedAssignment is roughly equivalent to the (expected) parsed output of the following function definition: -/* CREATE FUNCTION interpreted_assignment(input TEXT) RETURNS TEXT AS $$ -* DECLARE -* var1 TEXT; -* BEGIN -* var1 := 'Initial: ' || input; -* IF input = 'Hello' THEN -* var1 := var1 || ' - Greeting'; -* ELSIF input = 'Bye' THEN -* var1 := var1 || ' - Farewell'; -* ELSIF length(input) > 5 THEN -* var1 := var1 || ' - Over 5'; -* ELSE -* var1 := var1 || ' - Else'; -* END IF; -* RETURN var1; -* END; -* $$ LANGUAGE plpgsql; - */ -var interpretedAssignment = framework.InterpretedFunction{ - ID: id.NewFunction("pg_catalog", "interpreted_assignment", pgtypes.Text.ID), - ReturnType: pgtypes.Text, - ParameterNames: []string{"input"}, - ParameterTypes: []*pgtypes.DoltgresType{pgtypes.Text}, - Variadic: false, - IsNonDeterministic: false, - Strict: true, - Labels: nil, - Statements: []framework.InterpreterOperation{ - { // 0 - OpCode: framework.OpCode_ScopeBegin, - }, - { // 1 - OpCode: framework.OpCode_Declare, - PrimaryData: `text`, - Target: `var1`, - }, - { // 2 - OpCode: framework.OpCode_Assign, - PrimaryData: `SELECT 'Initial: ' || $1;`, - SecondaryData: []string{`input`}, - Target: `var1`, - }, - { // 3 - OpCode: framework.OpCode_ScopeBegin, - }, - { // 4 - OpCode: framework.OpCode_If, - PrimaryData: `SELECT $1 = 'Hello';`, - SecondaryData: []string{`input`}, - Index: 6, - }, - { // 5 - OpCode: framework.OpCode_Goto, - Index: 8, - }, - { // 6 - OpCode: 
framework.OpCode_Assign, - PrimaryData: `SELECT $1 || ' - Greeting';`, - SecondaryData: []string{`var1`}, - Target: `var1`, - }, - { // 7 - OpCode: framework.OpCode_Goto, - Index: 17, - }, - { // 8 - OpCode: framework.OpCode_If, - PrimaryData: `SELECT $1 = 'Bye';`, - SecondaryData: []string{`input`}, - Index: 10, - }, - { // 9 - OpCode: framework.OpCode_Goto, - Index: 12, - }, - { // 10 - OpCode: framework.OpCode_Assign, - PrimaryData: `SELECT $1 || ' - Farewell';`, - SecondaryData: []string{`var1`}, - Target: `var1`, - }, - { // 11 - OpCode: framework.OpCode_Goto, - Index: 17, - }, - { // 12 - OpCode: framework.OpCode_If, - PrimaryData: `SELECT length($1) > 5;`, - SecondaryData: []string{`input`}, - Index: 14, - }, - { // 13 - OpCode: framework.OpCode_Goto, - Index: 16, - }, - { // 14 - OpCode: framework.OpCode_Assign, - PrimaryData: `SELECT $1 || ' - Over 5';`, - SecondaryData: []string{`var1`}, - Target: `var1`, - }, - { // 15 - OpCode: framework.OpCode_Goto, - Index: 17, - }, - { // 16 - OpCode: framework.OpCode_Assign, - PrimaryData: `SELECT $1 || ' - Else';`, - SecondaryData: []string{`var1`}, - Target: `var1`, - }, - { // 17 - OpCode: framework.OpCode_ScopeEnd, - }, - { // 18 - OpCode: framework.OpCode_Return, - PrimaryData: `SELECT $1;`, - SecondaryData: []string{`var1`}, - }, - { // 19 - OpCode: framework.OpCode_ScopeEnd, - }, - }, -} - -// interpretedAlias is roughly equivalent to the (expected) parsed output of the following function definition: -/* -CREATE FUNCTION interpreted_alias(input TEXT) -RETURNS TEXT AS $$ -DECLARE - var1 TEXT; - var2 TEXT; -BEGIN - DECLARE - alias1 ALIAS FOR var1; - alias2 ALIAS FOR alias1; - alias3 ALIAS FOR input; - BEGIN - alias2 := alias3; - END; - RETURN var1; -END; -$$ LANGUAGE plpgsql; -*/ -var interpretedAlias = framework.InterpretedFunction{ - ID: id.NewFunction("pg_catalog", "interpreted_alias", pgtypes.Text.ID), - ReturnType: pgtypes.Text, - ParameterNames: []string{"input"}, - ParameterTypes: 
[]*pgtypes.DoltgresType{pgtypes.Text}, - Variadic: false, - IsNonDeterministic: false, - Strict: true, - Labels: nil, - Statements: []framework.InterpreterOperation{ - { // 0 - OpCode: framework.OpCode_ScopeBegin, - }, - { // 1 - OpCode: framework.OpCode_Declare, - Target: `var1`, - PrimaryData: `text`, - }, - { // 2 - OpCode: framework.OpCode_Declare, - Target: `var2`, - PrimaryData: `text`, - }, - { // 3 - OpCode: framework.OpCode_ScopeBegin, - }, - { // 4 - OpCode: framework.OpCode_Alias, - Target: `alias1`, - PrimaryData: `var1`, - }, - { // 5 - OpCode: framework.OpCode_Alias, - Target: `alias2`, - PrimaryData: `alias1`, - }, - { // 6 - OpCode: framework.OpCode_Alias, - Target: `alias3`, - PrimaryData: `input`, - }, - { // 7 - OpCode: framework.OpCode_ScopeBegin, - }, - { // 8 - OpCode: framework.OpCode_Assign, - PrimaryData: `SELECT $1;`, - SecondaryData: []string{`alias3`}, - Target: `alias2`, - }, - { // 9 - OpCode: framework.OpCode_ScopeEnd, - }, - { // 10 - OpCode: framework.OpCode_Return, - PrimaryData: `SELECT $1;`, - SecondaryData: []string{`var1`}, - }, - { // 11 - OpCode: framework.OpCode_ScopeEnd, - }, - }, -} diff --git a/server/node/create_function.go b/server/node/create_function.go new file mode 100644 index 0000000000..846c1a6218 --- /dev/null +++ b/server/node/create_function.go @@ -0,0 +1,118 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package node + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/functions/framework" + "github.com/dolthub/doltgresql/server/plpgsql" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// CreateFunction implements CREATE FUNCTION. +type CreateFunction struct { + FunctionName string + SchemaName string + ReturnType *pgtypes.DoltgresType + ParameterNames []string + ParameterTypes []*pgtypes.DoltgresType + Strict bool + Statements []plpgsql.InterpreterOperation +} + +var _ sql.ExecSourceRel = (*CreateFunction)(nil) +var _ vitess.Injectable = (*CreateFunction)(nil) + +// NewCreateFunction returns a new *CreateFunction. +func NewCreateFunction( + functionName string, + schemaName string, + retType *pgtypes.DoltgresType, + paramNames []string, + paramTypes []*pgtypes.DoltgresType, + strict bool, + statements []plpgsql.InterpreterOperation) *CreateFunction { + return &CreateFunction{ + FunctionName: functionName, + SchemaName: schemaName, + ReturnType: retType, + ParameterNames: paramNames, + ParameterTypes: paramTypes, + Strict: strict, + Statements: statements, + } +} + +// Children implements the interface sql.ExecSourceRel. +func (c *CreateFunction) Children() []sql.Node { + return nil +} + +// IsReadOnly implements the interface sql.ExecSourceRel. +func (c *CreateFunction) IsReadOnly() bool { + return false +} + +// Resolved implements the interface sql.ExecSourceRel. +func (c *CreateFunction) Resolved() bool { + return true +} + +// RowIter implements the interface sql.ExecSourceRel. 
+func (c *CreateFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + idTypes := make([]id.Type, len(c.ParameterTypes)) + for i, typ := range c.ParameterTypes { + idTypes[i] = typ.ID + } + framework.RegisterFunction(framework.InterpretedFunction{ + ID: id.NewFunction(c.SchemaName, c.FunctionName, idTypes...), + ReturnType: c.ReturnType, + ParameterNames: c.ParameterNames, + ParameterTypes: c.ParameterTypes, + Variadic: false, // TODO: implement this + IsNonDeterministic: true, + Strict: c.Strict, + Statements: c.Statements, + }) + return sql.RowsToRowIter(), nil +} + +// Schema implements the interface sql.ExecSourceRel. +func (c *CreateFunction) Schema() sql.Schema { + return nil +} + +// String implements the interface sql.ExecSourceRel. +func (c *CreateFunction) String() string { + // TODO: fully implement this + return "CREATE FUNCTION" +} + +// WithChildren implements the interface sql.ExecSourceRel. +func (c *CreateFunction) WithChildren(children ...sql.Node) (sql.Node, error) { + return plan.NillaryWithChildren(c, children...) +} + +// WithResolvedChildren implements the interface vitess.Injectable. +func (c *CreateFunction) WithResolvedChildren(children []any) (any, error) { + if len(children) != 0 { + return nil, ErrVitessChildCount.New(0, len(children)) + } + return c, nil +} diff --git a/server/functions/framework/interpreter_logic.go b/server/plpgsql/interpreter_logic.go similarity index 71% rename from server/functions/framework/interpreter_logic.go rename to server/plpgsql/interpreter_logic.go index a685674fc1..368c20dfca 100644 --- a/server/functions/framework/interpreter_logic.go +++ b/server/plpgsql/interpreter_logic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package framework +package plpgsql import ( "fmt" @@ -25,28 +25,42 @@ import ( pgtypes "github.com/dolthub/doltgresql/server/types" ) +// InterpretedFunction is an interface that essentially mirrors the implementation of InterpretedFunction in the +// framework package. +type InterpretedFunction interface { + GetParameters() []*pgtypes.DoltgresType + GetParameterNames() []string + GetReturn() *pgtypes.DoltgresType + GetStatements() []InterpreterOperation + QueryMultiReturn(ctx *sql.Context, stack InterpreterStack, stmt string, bindings []string) (rowIter sql.RowIter, err error) + QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) +} + // Call runs the contained operations on the given runner. -func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.StatementRunner, paramsAndReturn []*pgtypes.DoltgresType, vals []any) (any, error) { +func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.StatementRunner, paramsAndReturn []*pgtypes.DoltgresType, vals []any) (any, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack(runner) // Add the parameters - if len(vals) != len(iFunc.ParameterTypes) { - return nil, fmt.Errorf("parameter count mismatch: expected `%d` got %d`", len(iFunc.ParameterTypes), len(vals)) + parameterTypes := iFunc.GetParameters() + parameterNames := iFunc.GetParameterNames() + if len(vals) != len(parameterTypes) { + return nil, fmt.Errorf("parameter count mismatch: expected %d got %d", len(parameterTypes), len(vals)) } for i := range vals { - stack.NewVariableWithValue(iFunc.ParameterNames[i], iFunc.ParameterTypes[i], vals[i]) + stack.NewVariableWithValue(parameterNames[i], parameterTypes[i], vals[i]) } // Run the statements + statements := iFunc.GetStatements() for { counter++ - if counter >= len(iFunc.Statements) { + if counter >= 
len(statements) { break } else if counter < 0 { panic("negative function counter") } - operation := iFunc.Statements[counter] + operation := statements[counter] switch operation.OpCode { case OpCode_Alias: iv := stack.GetVariable(operation.PrimaryData) @@ -59,7 +73,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen if iv == nil { return nil, fmt.Errorf("variable `%s` could not be found", operation.Target) } - retVal, err := iFunc.querySingleReturn(ctx, stack, operation.PrimaryData, iv.Type, operation.SecondaryData) + retVal, err := iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, iv.Type, operation.SecondaryData) if err != nil { return nil, err } @@ -84,7 +98,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen case OpCode_Exception: // TODO: implement case OpCode_Execute: - rowIter, err := iFunc.queryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) if err != nil { return nil, err } @@ -101,7 +115,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { for ; counter < operation.Index-1; counter++ { - switch iFunc.Statements[counter].OpCode { + switch statements[counter].OpCode { case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: @@ -110,7 +124,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen } } else { for ; counter > operation.Index-1; counter-- { - switch iFunc.Statements[counter].OpCode { + switch statements[counter].OpCode { case OpCode_ScopeBegin: stack.PopScope() case OpCode_ScopeEnd: @@ -119,7 +133,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen } } case OpCode_If: - retVal, err := iFunc.querySingleReturn(ctx, stack, operation.PrimaryData, 
pgtypes.Bool, operation.SecondaryData) + retVal, err := iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, pgtypes.Bool, operation.SecondaryData) if err != nil { return nil, err } @@ -133,7 +147,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen case OpCode_Loop: // TODO: implement case OpCode_Perform: - rowIter, err := iFunc.queryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) if err != nil { return nil, err } @@ -141,7 +155,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen return nil, err } case OpCode_Query: - rowIter, err := iFunc.queryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) if err != nil { return nil, err } @@ -152,7 +166,7 @@ func (iFunc InterpretedFunction) Call(ctx *sql.Context, runner analyzer.Statemen if len(operation.PrimaryData) == 0 { return nil, nil } - return iFunc.querySingleReturn(ctx, stack, operation.PrimaryData, iFunc.ReturnType, operation.SecondaryData) + return iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, iFunc.GetReturn(), operation.SecondaryData) case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: diff --git a/server/functions/framework/interpreter_operation.go b/server/plpgsql/interpreter_operation.go similarity index 99% rename from server/functions/framework/interpreter_operation.go rename to server/plpgsql/interpreter_operation.go index 1a1ff17cfb..093bd60a08 100644 --- a/server/functions/framework/interpreter_operation.go +++ b/server/plpgsql/interpreter_operation.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package framework +package plpgsql // OpCode states the operation to be performed. 
Most operations have a direct analogue to a Pl/pgSQL operation, however // some exist only in Doltgres (specific to our interpreter implementation). diff --git a/server/functions/framework/interpreter_stack.go b/server/plpgsql/interpreter_stack.go similarity index 89% rename from server/functions/framework/interpreter_stack.go rename to server/plpgsql/interpreter_stack.go index 2bfc92e572..f5d1eb70f3 100644 --- a/server/functions/framework/interpreter_stack.go +++ b/server/plpgsql/interpreter_stack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package framework +package plpgsql import ( "fmt" @@ -61,6 +61,11 @@ func (is *InterpreterStack) Details() *InterpreterScopeDetails { return is.stack.Peek() } +// Runner returns the runner that is being used for the function's execution. +func (is *InterpreterStack) Runner() analyzer.StatementRunner { + return is.runner +} + // GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { @@ -72,6 +77,17 @@ func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { return nil } +// ListVariables returns a map with the names of all variables. +func (is *InterpreterStack) ListVariables() map[string]struct{} { + seen := make(map[string]struct{}) + for i := 0; i < is.stack.Len(); i++ { + for varName := range is.stack.PeekDepth(i).variables { + seen[varName] = struct{}{} + } + } + return seen +} + // NewVariable creates a new variable in the current scope. If a variable with the same name exists in a previous scope, // then that variable will be shadowed until the current scope exits. 
func (is *InterpreterStack) NewVariable(name string, typ *pgtypes.DoltgresType) { diff --git a/server/plpgsql/json.go b/server/plpgsql/json.go new file mode 100644 index 0000000000..c435c0e11b --- /dev/null +++ b/server/plpgsql/json.go @@ -0,0 +1,206 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plpgsql + +import ( + "strings" + + "github.com/cockroachdb/errors" +) + +// action exists to match the expected JSON format. +type action struct { + StmtBlock plpgSQL_stmt_block `json:"PLpgSQL_stmt_block"` +} + +// cond exists to match the expected JSON format. +type cond struct { + Expression plpgSQL_expr `json:"PLpgSQL_expr"` +} + +// datatype exists to match the expected JSON format. +type datatype struct { + Type plpgSQL_type `json:"PLpgSQL_type"` +} + +// elsif exists to match the expected JSON format. +type elsif struct { + ElseIf plpgSQL_if_elsif `json:"PLpgSQL_if_elsif"` +} + +// expr exists to match the expected JSON format. +type expr struct { + Expression plpgSQL_expr `json:"PLpgSQL_expr"` +} + +// datum exists to match the expected JSON format. +type datum struct { + Variable plpgSQL_var `json:"PLpgSQL_var"` +} + +// function exists to match the expected JSON format. +type function struct { + Function plpgSQL_block `json:"PLpgSQL_function"` +} + +// plpgSQL_block exists to match the expected JSON format. 
+type plpgSQL_block struct { + Datums []datum `json:"datums"` + Action action `json:"action"` +} + +// plpgSQL_expr exists to match the expected JSON format. +type plpgSQL_expr struct { + Query string `json:"query"` + ParseMode int32 `json:"parseMode"` +} + +// plpgSQL_stmt_assign exists to match the expected JSON format. +type plpgSQL_stmt_assign struct { + Expression expr `json:"expr"` + VariableNumber int32 `json:"varno"` + LineNumber int32 `json:"lineno"` +} + +// plpgSQL_stmt_block exists to match the expected JSON format. +type plpgSQL_stmt_block struct { + Body []statement `json:"body"` + LineNumber int32 `json:"lineno"` +} + +// plpgSQL_if_elsif exists to match the expected JSON format. +type plpgSQL_if_elsif struct { + Condition cond `json:"cond"` + Then []statement `json:"stmts"` + LineNumber int32 `json:"lineno"` +} + +// plpgSQL_stmt_if exists to match the expected JSON format. +type plpgSQL_stmt_if struct { + Condition cond `json:"cond"` + Then []statement `json:"then_body"` + ElseIf []elsif `json:"elsif_list"` + Else []statement `json:"else_body"` + LineNumber int32 `json:"lineno"` +} + +// plpgSQL_stmt_return exists to match the expected JSON format. +type plpgSQL_stmt_return struct { + Expression expr `json:"expr"` + LineNumber int32 `json:"lineno"` +} + +// plpgSQL_type exists to match the expected JSON format. +type plpgSQL_type struct { + Name string `json:"typname"` +} + +// plpgSQL_var exists to match the expected JSON format. +type plpgSQL_var struct { + RefName string `json:"refname"` + Type datatype `json:"datatype"` + LineNumber int32 `json:"lineno"` +} + +// statement exists to match the expected JSON format. Unlike other structs, this is used like a union rather than +// having a singular expected implementation. 
+type statement struct {
+	Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
+	If         *plpgSQL_stmt_if     `json:"PLpgSQL_stmt_if"`
+	Return     *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
+}
+
+// Convert converts the JSON statement into its output form.
+func (stmt *plpgSQL_stmt_assign) Convert() (Assignment, error) {
+	query := stmt.Expression.Expression.Query
+	varName := ""
+	if equalsIdx := strings.Index(query, ":="); equalsIdx > 0 {
+		varName = strings.TrimSpace(query[:equalsIdx])
+		query = strings.TrimSpace(query[equalsIdx+2:])
+	} else if equalsIdx = strings.Index(query, "="); equalsIdx > 0 {
+		varName = strings.TrimSpace(query[:equalsIdx])
+		query = strings.TrimSpace(query[equalsIdx+1:])
+	} else {
+		return Assignment{}, errors.New("PL/pgSQL assignment cannot find `:=` or `=` sign")
+	}
+	return Assignment{
+		VariableName:  varName,
+		Expression:    query,
+		VariableIndex: stmt.VariableNumber,
+	}, nil
+}
+
+// Convert converts the JSON statement into its output form.
+func (stmt *plpgSQL_stmt_if) Convert() (Block, error) {
+	// We store all GOTOs that will need to go to the end of the block. Since we can't know that ahead of time, we store
+	// their indexes and set them at the end of the function.
+ var gotoEndIndexes []int32 + returnBlock := Block{ + Body: []Statement{ + If{ + Condition: stmt.Condition.Expression.Query, + GotoOffset: 2, // The operation following the conditional skips the THEN statements, so we're skipping that + }, + }, + } + // We'll parse our THEN statements, but we won't add them to the block just yet as we need their operation sizes + thenStmts, err := jsonConvertStatements(stmt.Then) + if err != nil { + return Block{}, err + } + // When the condition is false, we want to skip our THEN block, so we do that (plus the GOTO which finishes the THEN block) + returnBlock.Body = append(returnBlock.Body, Goto{Offset: OperationSizeForStatements(thenStmts) + 2}) + // Then we'll append our THEN block + returnBlock.Body = append(returnBlock.Body, thenStmts...) + // Then we want to append the GOTO that finishes the THEN block, but we don't know the end just yet, so we'll save + // its index and fill it in later + gotoEndIndexes = append(gotoEndIndexes, int32(len(returnBlock.Body))) + returnBlock.Body = append(returnBlock.Body, Goto{}) + // We repeat the same process for each ELSIF statement (refer to the comments above) + for _, elseIf := range stmt.ElseIf { + returnBlock.Body = append(returnBlock.Body, If{ + Condition: elseIf.ElseIf.Condition.Expression.Query, + GotoOffset: 2, // Same rules as skipping our THEN statement above + }) + elseIfStmts, err := jsonConvertStatements(elseIf.ElseIf.Then) + if err != nil { + return Block{}, err + } + returnBlock.Body = append(returnBlock.Body, Goto{Offset: OperationSizeForStatements(elseIfStmts) + 2}) + returnBlock.Body = append(returnBlock.Body, elseIfStmts...) + gotoEndIndexes = append(gotoEndIndexes, int32(len(returnBlock.Body))) + returnBlock.Body = append(returnBlock.Body, Goto{}) + } + // Finally we handle our ELSE statements. We don't have a condition to check, so we don't have to append any + // additional GOTOs. 
+ elseStmts, err := jsonConvertStatements(stmt.Else) + if err != nil { + return Block{}, err + } + returnBlock.Body = append(returnBlock.Body, elseStmts...) + // Now we'll set all of our GOTOs so that they skip to the end of the block. + // We have to take their index position into account, since we want to skip to the end from their relative position. + for _, gotoEndIndex := range gotoEndIndexes { + returnBlock.Body[gotoEndIndex] = Goto{Offset: int32(len(returnBlock.Body)) - gotoEndIndex} + } + return returnBlock, nil +} + +// Convert converts the JSON statement into its output form. +func (stmt *plpgSQL_stmt_return) Convert() Return { + return Return{ + Expression: stmt.Expression.Expression.Query, + } +} diff --git a/server/plpgsql/json_convert.go b/server/plpgsql/json_convert.go new file mode 100644 index 0000000000..746dc28506 --- /dev/null +++ b/server/plpgsql/json_convert.go @@ -0,0 +1,66 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plpgsql + +import ( + "strings" + + "github.com/cockroachdb/errors" +) + +// jsonConvert handles the conversion from the JSON format into a format that is easier to work with. 
+func jsonConvert(jsonBlock plpgSQL_block) (Block, error) {
+	block := Block{}
+	for _, v := range jsonBlock.Datums {
+		block.Variable = append(block.Variable, Variable{
+			Name:        v.Variable.RefName,
+			Type:        strings.ToLower(v.Variable.Type.Type.Name),
+			IsParameter: v.Variable.LineNumber == 0,
+		})
+	}
+	var err error
+	block.Body, err = jsonConvertStatements(jsonBlock.Action.StmtBlock.Body)
+	if err != nil {
+		return Block{}, err
+	}
+	return block, nil
+}
+
+// jsonConvertStatement converts a statement in JSON form to the output form.
+func jsonConvertStatement(stmt statement) (Statement, error) {
+	switch {
+	case stmt.Assignment != nil:
+		return stmt.Assignment.Convert()
+	case stmt.If != nil:
+		return stmt.If.Convert()
+	case stmt.Return != nil:
+		return stmt.Return.Convert(), nil
+	default:
+		return Block{}, errors.New("unhandled statement: no recognized PL/pgSQL statement field was set")
+	}
+}
+
+// jsonConvertStatements converts a collection of statements in JSON form to their output form.
+func jsonConvertStatements(stmts []statement) ([]Statement, error) {
+	newStmts := make([]Statement, len(stmts))
+	for i, stmt := range stmts {
+		var err error
+		newStmts[i], err = jsonConvertStatement(stmt)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return newStmts, nil
+}
diff --git a/server/plpgsql/parse.go b/server/plpgsql/parse.go
new file mode 100644
index 0000000000..095dc07174
--- /dev/null
+++ b/server/plpgsql/parse.go
@@ -0,0 +1,47 @@
+// Copyright 2025 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package plpgsql + +import ( + "encoding/json" + + "github.com/cockroachdb/errors" + pg_query "github.com/pganalyze/pg_query_go/v6" +) + +// Parse parses the given CREATE FUNCTION string (which must be the entire string, not just the body) into a Block +// containing the contents of the body. +func Parse(fullCreateFunctionString string) ([]InterpreterOperation, error) { + var functions []function + parsedBody, err := pg_query.ParsePlPgSqlToJSON(fullCreateFunctionString) + if err != nil { + return nil, err + } + err = json.Unmarshal([]byte(parsedBody), &functions) + if err != nil { + return nil, err + } + if len(functions) != 1 { + return nil, errors.New("CREATE FUNCTION parsed multiple blocks") + } + block, err := jsonConvert(functions[0].Function) + if err != nil { + return nil, err + } + ops := make([]InterpreterOperation, 0, len(block.Body)+len(block.Variable)) + stack := NewInterpreterStack(nil) + block.AppendOperations(&ops, &stack) + return ops, nil +} diff --git a/server/plpgsql/statements.go b/server/plpgsql/statements.go new file mode 100644 index 0000000000..5ee6700e92 --- /dev/null +++ b/server/plpgsql/statements.go @@ -0,0 +1,206 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plpgsql + +import ( + "fmt" + "strings" +) + +// Statement represents a PL/pgSQL statement. 
+type Statement interface { + // OperationSize reports the number of operations that the statement will convert to. + OperationSize() int32 + // AppendOperations adds the statement to the operation slice. + AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) +} + +// Assignment represents an assignment statement. +type Assignment struct { + VariableName string + Expression string + VariableIndex int32 // TODO: figure out what this is used for, probably to get around shadowed variables? +} + +var _ Statement = Assignment{} + +// OperationSize implements the interface Statement. +func (Assignment) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { + // TODO: figure out how I'm supposed to actually do this, rather than a search and replace + expression := stmt.Expression + var referencedVariables []string + for varName := range stack.ListVariables() { + if strings.Contains(expression, varName) { + referencedVariables = append(referencedVariables, varName) + } + expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Assign, + PrimaryData: "SELECT " + expression + ";", + SecondaryData: referencedVariables, + Target: stmt.VariableName, + }) +} + +// Block contains a collection of statements, alongside the variables that were declared for the block. Only the +// top-level block will contain parameter variables. +type Block struct { + Variable []Variable + Body []Statement +} + +var _ Statement = Block{} + +// OperationSize implements the interface Statement. 
+func (stmt Block) OperationSize() int32 { + total := int32(2) // We start with 2 since we'll have ScopeBegin and ScopeEnd + for _, variable := range stmt.Variable { + if !variable.IsParameter { + total++ + } + } + for _, innerStmt := range stmt.Body { + total += innerStmt.OperationSize() + } + return total +} + +// AppendOperations implements the interface Statement. +func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { + stack.PushScope() + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_ScopeBegin, + }) + for _, variable := range stmt.Variable { + if !variable.IsParameter { + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Declare, + PrimaryData: variable.Type, + Target: variable.Name, + }) + } + stack.NewVariableWithValue(variable.Name, nil, nil) + } + for _, innerStmt := range stmt.Body { + innerStmt.AppendOperations(ops, stack) + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_ScopeEnd, + }) + stack.PopScope() +} + +// Goto jumps to the counter at the given offset. +type Goto struct { + Offset int32 +} + +var _ Statement = Goto{} + +// OperationSize implements the interface Statement. +func (Goto) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Goto) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Goto, + Index: len(*ops) + int(stmt.Offset), + }) +} + +// If represents an IF condition, alongside its Goto offset if the condition is true. +type If struct { + Condition string + GotoOffset int32 +} + +var _ Statement = If{} + +// OperationSize implements the interface Statement. +func (If) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. 
+func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { + condition := stmt.Condition + var referencedVariables []string + for varName := range stack.ListVariables() { + if strings.Contains(condition, varName) { + referencedVariables = append(referencedVariables, varName) + } + condition = strings.Replace(condition, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: "SELECT " + condition + ";", + SecondaryData: referencedVariables, + Index: len(*ops) + int(stmt.GotoOffset), + }) +} + +// Return represents a RETURN statement. +type Return struct { + Expression string +} + +var _ Statement = Return{} + +// OperationSize implements the interface Statement. +func (Return) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { + // TODO: figure out how I'm supposed to actually do this, rather than a search and replace + expression := stmt.Expression + var referencedVariables []string + for varName := range stack.ListVariables() { + if strings.Contains(expression, varName) { + referencedVariables = append(referencedVariables, varName) + } + expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Return, + PrimaryData: "SELECT " + expression + ";", + SecondaryData: referencedVariables, + }) +} + +// Variable represents a variable. These are exclusively found within Block. +type Variable struct { + Name string + Type string + IsParameter bool +} + +// OperationSizeForStatements returns the sum of OperationSize for every statement. 
+func OperationSizeForStatements(stmts []Statement) int32 { + total := int32(0) + for _, stmt := range stmts { + total += stmt.OperationSize() + } + return total +} diff --git a/testing/go/create_function_test.go b/testing/go/create_function_test.go index 53a6c8f3e1..8fda9466a9 100644 --- a/testing/go/create_function_test.go +++ b/testing/go/create_function_test.go @@ -24,6 +24,24 @@ func TestCreateFunction(t *testing.T) { RunScripts(t, []ScriptTest{ { Name: "Interpreter Assignment Example", + Skip: true, // TODO: need to use a Doltgres function provider, as the current one doesn't allow for adding functions + SetUpScript: []string{`CREATE FUNCTION interpreted_assignment(input TEXT) RETURNS TEXT AS $$ +DECLARE + var1 TEXT; +BEGIN + var1 := 'Initial: ' || input; + IF input = 'Hello' THEN + var1 := var1 || ' - Greeting'; + ELSIF input = 'Bye' THEN + var1 := var1 || ' - Farewell'; + ELSIF length(input) > 5 THEN + var1 := var1 || ' - Over 5'; + ELSE + var1 := var1 || ' - Else'; + END IF; + RETURN var1; +END; +$$ LANGUAGE plpgsql;`}, Assertions: []ScriptTestAssertion{ { Query: "SELECT interpreted_assignment('Hello');", @@ -53,6 +71,28 @@ func TestCreateFunction(t *testing.T) { }, { Name: "Interpreter Alias Example", + // TODO: need to use a Doltgres function provider, and need to implement the + // OpCode conversion for parsed ALIAS statements. + Skip: true, + SetUpScript: []string{ + `CREATE FUNCTION interpreted_alias(input TEXT) + RETURNS TEXT AS $$ + DECLARE + var1 TEXT; + var2 TEXT; + BEGIN + DECLARE + alias1 ALIAS FOR var1; + alias2 ALIAS FOR alias1; + alias3 ALIAS FOR input; + BEGIN + alias2 := alias3; + END; + RETURN var1; + END; + $$ LANGUAGE plpgsql; + `, + }, Assertions: []ScriptTestAssertion{ { Query: "SELECT interpreted_alias('123');",