diff --git a/server/plpgsql/statements.go b/server/plpgsql/statements.go index 5ee6700e92..62cec4eede 100644 --- a/server/plpgsql/statements.go +++ b/server/plpgsql/statements.go @@ -16,7 +16,8 @@ package plpgsql import ( "fmt" - "strings" + + pg_query "github.com/pganalyze/pg_query_go/v6" ) // Statement represents a PL/pgSQL statement. @@ -43,15 +44,12 @@ func (Assignment) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { - // TODO: figure out how I'm supposed to actually do this, rather than a search and replace - expression := stmt.Expression - var referencedVariables []string - for varName := range stack.ListVariables() { - if strings.Contains(expression, varName) { - referencedVariables = append(referencedVariables, varName) - } - expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) + if err != nil { + // TODO: add an error return param instead of panicing + panic(err) } + *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Assign, PrimaryData: "SELECT " + expression + ";", @@ -143,14 +141,12 @@ func (If) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { - condition := stmt.Condition - var referencedVariables []string - for varName := range stack.ListVariables() { - if strings.Contains(condition, varName) { - referencedVariables = append(referencedVariables, varName) - } - condition = strings.Replace(condition, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + condition, referencedVariables, err := substituteVariableReferences(stmt.Condition, stack) + if err != nil { + // TODO: add an error return param instead of panicing + panic(err) } + *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_If, PrimaryData: "SELECT " + condition + ";", @@ -173,15 +169,12 @@ func (Return) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) { - // TODO: figure out how I'm supposed to actually do this, rather than a search and replace - expression := stmt.Expression - var referencedVariables []string - for varName := range stack.ListVariables() { - if strings.Contains(expression, varName) { - referencedVariables = append(referencedVariables, varName) - } - expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1) + expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) + if err != nil { + // TODO: add an error return param instead of panicing + panic(err) } + *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Return, PrimaryData: "SELECT " + expression + ";", @@ -204,3 +197,29 @@ func OperationSizeForStatements(stmts []Statement) int32 { } return total } + +// substituteVariableReferences parses the specified |expression| and replaces +// any token that matches a variable name in the |stack| with "$N", where N +// indicates which variable in the returned |referenceVars| slice is used. +func substituteVariableReferences(expression string, stack *InterpreterStack) (newExpression string, referencedVars []string, err error) { + scanResult, err := pg_query.Scan(expression) + if err != nil { + return "", nil, err + } + + varMap := stack.ListVariables() + for _, token := range scanResult.Tokens { + substring := expression[token.Start:token.End] + if _, ok := varMap[substring]; ok { + // TODO: There's another bug here where the same variable could + // be referenced multiple times in the expression and the + // length of the slice won't be the correct index. + referencedVars = append(referencedVars, substring) + newExpression += fmt.Sprintf("$%d ", len(referencedVars)) + } else { + newExpression += substring + " " + } + } + + return newExpression, referencedVars, nil +} diff --git a/testing/go/create_function_test.go b/testing/go/create_function_test.go index 4976065709..bf773a782d 100644 --- a/testing/go/create_function_test.go +++ b/testing/go/create_function_test.go @@ -70,8 +70,7 @@ $$ LANGUAGE plpgsql;`}, }, { Name: "Interpreter Alias Example", - // TODO: need to use a Doltgres function provider, and need to implement the - // OpCode conversion for parsed ALIAS statements. + // TODO: Implement OpCode conversion for parsed ALIAS statements. Skip: true, SetUpScript: []string{ `CREATE FUNCTION interpreted_alias(input TEXT) @@ -99,5 +98,30 @@ $$ LANGUAGE plpgsql;`}, }, }, }, + { + // Tests that variable names are correctly substituted with references + // to the variables when the function is parsed. + Name: "Variable reference substitution", + SetUpScript: []string{` +CREATE FUNCTION test1(input TEXT) RETURNS TEXT AS $$ +DECLARE + var1 TEXT; +BEGIN + var1 := 'input' || input; + IF var1 = 'input' || input THEN + RETURN var1 || 'var1'; + ELSE + RETURN '!!!'; + END IF; +END; +$$ LANGUAGE plpgsql;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT test1('Hello');", + Expected: []sql.Row{{"inputHellovar1"}}, + }, + }, + }, }) }