Skip to content

Commit 8094922

Browse files
committed
Fixing bugs in function variable reference substitution
1 parent 14a2484 commit 8094922

File tree

2 files changed

+69
-26
lines changed

2 files changed

+69
-26
lines changed

server/plpgsql/statements.go

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ package plpgsql
1616

1717
import (
1818
"fmt"
19-
"strings"
19+
20+
pg_query "github.com/pganalyze/pg_query_go/v6"
2021
)
2122

2223
// Statement represents a PL/pgSQL statement.
@@ -43,15 +44,12 @@ func (Assignment) OperationSize() int32 {
4344

4445
// AppendOperations implements the interface Statement.
4546
func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
46-
// TODO: figure out how I'm supposed to actually do this, rather than a search and replace
47-
expression := stmt.Expression
48-
var referencedVariables []string
49-
for varName := range stack.ListVariables() {
50-
if strings.Contains(expression, varName) {
51-
referencedVariables = append(referencedVariables, varName)
52-
}
53-
expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
47+
expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack)
48+
if err != nil {
49+
// TODO: add an error return param instead of panicing
50+
panic(err)
5451
}
52+
5553
*ops = append(*ops, InterpreterOperation{
5654
OpCode: OpCode_Assign,
5755
PrimaryData: "SELECT " + expression + ";",
@@ -143,14 +141,12 @@ func (If) OperationSize() int32 {
143141

144142
// AppendOperations implements the interface Statement.
145143
func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
146-
condition := stmt.Condition
147-
var referencedVariables []string
148-
for varName := range stack.ListVariables() {
149-
if strings.Contains(condition, varName) {
150-
referencedVariables = append(referencedVariables, varName)
151-
}
152-
condition = strings.Replace(condition, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
144+
condition, referencedVariables, err := substituteVariableReferences(stmt.Condition, stack)
145+
if err != nil {
146+
// TODO: add an error return param instead of panicing
147+
panic(err)
153148
}
149+
154150
*ops = append(*ops, InterpreterOperation{
155151
OpCode: OpCode_If,
156152
PrimaryData: "SELECT " + condition + ";",
@@ -173,15 +169,12 @@ func (Return) OperationSize() int32 {
173169

174170
// AppendOperations implements the interface Statement.
175171
func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) {
176-
// TODO: figure out how I'm supposed to actually do this, rather than a search and replace
177-
expression := stmt.Expression
178-
var referencedVariables []string
179-
for varName := range stack.ListVariables() {
180-
if strings.Contains(expression, varName) {
181-
referencedVariables = append(referencedVariables, varName)
182-
}
183-
expression = strings.Replace(expression, varName, fmt.Sprintf("$%d", len(referencedVariables)), 1)
172+
expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack)
173+
if err != nil {
174+
// TODO: add an error return param instead of panicing
175+
panic(err)
184176
}
177+
185178
*ops = append(*ops, InterpreterOperation{
186179
OpCode: OpCode_Return,
187180
PrimaryData: "SELECT " + expression + ";",
@@ -204,3 +197,29 @@ func OperationSizeForStatements(stmts []Statement) int32 {
204197
}
205198
return total
206199
}
200+
201+
// substituteVariableReferences parses the specified |expression| and replaces
202+
// any token that matches a variable name in the |stack| with "$N", where N
203+
// indicates which variable in the returned |referenceVars| slice is used.
204+
func substituteVariableReferences(expression string, stack *InterpreterStack) (newExpression string, referencedVars []string, err error) {
205+
scanResult, err := pg_query.Scan(expression)
206+
if err != nil {
207+
return "", nil, err
208+
}
209+
210+
varMap := stack.ListVariables()
211+
for _, token := range scanResult.Tokens {
212+
substring := expression[token.Start:token.End]
213+
if _, ok := varMap[substring]; ok {
214+
// TODO: There's another bug here where the same variable could
215+
// be referenced multiple times in the expression and the
216+
// length of the slice won't be the correct index.
217+
referencedVars = append(referencedVars, substring)
218+
newExpression += fmt.Sprintf("$%d ", len(referencedVars))
219+
} else {
220+
newExpression += substring + " "
221+
}
222+
}
223+
224+
return newExpression, referencedVars, nil
225+
}

testing/go/create_function_test.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ $$ LANGUAGE plpgsql;`},
7070
},
7171
{
7272
Name: "Interpreter Alias Example",
73-
// TODO: need to use a Doltgres function provider, and need to implement the
74-
// OpCode conversion for parsed ALIAS statements.
73+
// TODO: Implement OpCode conversion for parsed ALIAS statements.
7574
Skip: true,
7675
SetUpScript: []string{
7776
`CREATE FUNCTION interpreted_alias(input TEXT)
@@ -99,5 +98,30 @@ $$ LANGUAGE plpgsql;`},
9998
},
10099
},
101100
},
101+
{
102+
// Tests that variable names are correctly substituted with references
103+
// to the variables when the function is parsed.
104+
Name: "Variable reference substitution",
105+
SetUpScript: []string{`
106+
CREATE FUNCTION test1(input TEXT) RETURNS TEXT AS $$
107+
DECLARE
108+
var1 TEXT;
109+
BEGIN
110+
var1 := 'input' || input;
111+
IF var1 = 'input' || input THEN
112+
RETURN var1 || 'var1';
113+
ELSE
114+
RETURN '!!!';
115+
END IF;
116+
END;
117+
$$ LANGUAGE plpgsql;`,
118+
},
119+
Assertions: []ScriptTestAssertion{
120+
{
121+
Query: "SELECT test1('Hello');",
122+
Expected: []sql.Row{{"inputHellovar1"}},
123+
},
124+
},
125+
},
102126
})
103127
}

0 commit comments

Comments
 (0)