Skip to content

Commit

Permalink
Fixing bugs in function variable reference substitution
Browse files Browse the repository at this point in the history
  • Loading branch information
fulghum committed Feb 7, 2025
1 parent 14a2484 commit 8094922
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 26 deletions.
67 changes: 43 additions & 24 deletions server/plpgsql/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ package plpgsql

import (
"fmt"
"strings"

pg_query "github.com/pganalyze/pg_query_go/v6"
)

// Statement represents a PL/pgSQL statement.
Expand All @@ -43,15 +44,12 @@ func (Assignment) OperationSize() int32 {

// AppendOperations implements the interface Statement.
func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
// TODO: figure out how I'm supposed to actually do this, rather than a search and replace
expression := stmt.Expression
var referencedVariables []string
for varName := range stack.ListVariables() {
if strings.Contains(expression, varName) {
referencedVariables = append(referencedVariables, varName)
}
expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack)
if err != nil {
// TODO: add an error return param instead of panicing
panic(err)
}

*ops = append(*ops, InterpreterOperation{
OpCode: OpCode_Assign,
PrimaryData: "SELECT " + expression + ";",
Expand Down Expand Up @@ -143,14 +141,12 @@ func (If) OperationSize() int32 {

// AppendOperations implements the interface Statement.
func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
condition := stmt.Condition
var referencedVariables []string
for varName := range stack.ListVariables() {
if strings.Contains(condition, varName) {
referencedVariables = append(referencedVariables, varName)
}
condition = strings.Replace(condition, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
condition, referencedVariables, err := substituteVariableReferences(stmt.Condition, stack)
if err != nil {
// TODO: add an error return param instead of panicing
panic(err)
}

*ops = append(*ops, InterpreterOperation{
OpCode: OpCode_If,
PrimaryData: "SELECT " + condition + ";",
Expand All @@ -173,15 +169,12 @@ func (Return) OperationSize() int32 {

// AppendOperations implements the interface Statement.
func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
// TODO: figure out how I'm supposed to actually do this, rather than a search and replace
expression := stmt.Expression
var referencedVariables []string
for varName := range stack.ListVariables() {
if strings.Contains(expression, varName) {
referencedVariables = append(referencedVariables, varName)
}
expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack)
if err != nil {
// TODO: add an error return param instead of panicing
panic(err)
}

*ops = append(*ops, InterpreterOperation{
OpCode: OpCode_Return,
PrimaryData: "SELECT " + expression + ";",
Expand All @@ -204,3 +197,29 @@ func OperationSizeForStatements(stmts []Statement) int32 {
}
return total
}

// substituteVariableReferences parses the specified |expression| and replaces
// any token that matches a variable name in the |stack| with "$N", where N
// indicates which variable in the returned |referenceVars| slice is used.
func substituteVariableReferences(expression string, stack *InterpreterStack) (newExpression string, referencedVars []string, err error) {
scanResult, err := pg_query.Scan(expression)
if err != nil {
return "", nil, err
}

varMap := stack.ListVariables()
for _, token := range scanResult.Tokens {
substring := expression[token.Start:token.End]
if _, ok := varMap[substring]; ok {
// TODO: There's another bug here where the same variable could
// be referenced multiple times in the expression and the
// length of the slice won't be the correct index.
referencedVars = append(referencedVars, substring)
newExpression += fmt.Sprintf("$%d ", len(referencedVars))
} else {
newExpression += substring + " "
}
}

return newExpression, referencedVars, nil
}
28 changes: 26 additions & 2 deletions testing/go/create_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ $$ LANGUAGE plpgsql;`},
},
{
Name: "Interpreter Alias Example",
// TODO: need to use a Doltgres function provider, and need to implement the
// OpCode conversion for parsed ALIAS statements.
// TODO: Implement OpCode conversion for parsed ALIAS statements.
Skip: true,
SetUpScript: []string{
`CREATE FUNCTION interpreted_alias(input TEXT)
Expand Down Expand Up @@ -99,5 +98,30 @@ $$ LANGUAGE plpgsql;`},
},
},
},
{
// Tests that variable names are correctly substituted with references
// to the variables when the function is parsed.
Name: "Variable reference substitution",
SetUpScript: []string{`
CREATE FUNCTION test1(input TEXT) RETURNS TEXT AS $$
DECLARE
var1 TEXT;
BEGIN
var1 := 'input' || input;
IF var1 = 'input' || input THEN
RETURN var1 || 'var1';
ELSE
RETURN '!!!';
END IF;
END;
$$ LANGUAGE plpgsql;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT test1('Hello');",
Expected: []sql.Row{{"inputHellovar1"}},
},
},
},
})
}

0 comments on commit 8094922

Please sign in to comment.