@@ -16,7 +16,8 @@ package plpgsql
16
16
17
17
import (
18
18
"fmt"
19
- "strings"
19
+
20
+ pg_query "github.com/pganalyze/pg_query_go/v6"
20
21
)
21
22
22
23
// Statement represents a PL/pgSQL statement.
@@ -43,15 +44,12 @@ func (Assignment) OperationSize() int32 {
43
44
44
45
// AppendOperations implements the interface Statement.
45
46
func (stmt Assignment ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
46
- // TODO: figure out how I'm supposed to actually do this, rather than a search and replace
47
- expression := stmt .Expression
48
- var referencedVariables []string
49
- for varName := range stack .ListVariables () {
50
- if strings .Contains (expression , varName ) {
51
- referencedVariables = append (referencedVariables , varName )
52
- }
53
- expression = strings .Replace (expression , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
47
+ expression , referencedVariables , err := substituteVariableReferences (stmt .Expression , stack )
48
+ if err != nil {
49
+ // TODO: add an error return param instead of panicing
50
+ panic (err )
54
51
}
52
+
55
53
* ops = append (* ops , InterpreterOperation {
56
54
OpCode : OpCode_Assign ,
57
55
PrimaryData : "SELECT " + expression + ";" ,
@@ -143,14 +141,12 @@ func (If) OperationSize() int32 {
143
141
144
142
// AppendOperations implements the interface Statement.
145
143
func (stmt If ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
146
- condition := stmt .Condition
147
- var referencedVariables []string
148
- for varName := range stack .ListVariables () {
149
- if strings .Contains (condition , varName ) {
150
- referencedVariables = append (referencedVariables , varName )
151
- }
152
- condition = strings .Replace (condition , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
144
+ condition , referencedVariables , err := substituteVariableReferences (stmt .Condition , stack )
145
+ if err != nil {
146
+ // TODO: add an error return param instead of panicing
147
+ panic (err )
153
148
}
149
+
154
150
* ops = append (* ops , InterpreterOperation {
155
151
OpCode : OpCode_If ,
156
152
PrimaryData : "SELECT " + condition + ";" ,
@@ -173,15 +169,12 @@ func (Return) OperationSize() int32 {
173
169
174
170
// AppendOperations implements the interface Statement.
175
171
func (stmt Return ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
176
- // TODO: figure out how I'm supposed to actually do this, rather than a search and replace
177
- expression := stmt .Expression
178
- var referencedVariables []string
179
- for varName := range stack .ListVariables () {
180
- if strings .Contains (expression , varName ) {
181
- referencedVariables = append (referencedVariables , varName )
182
- }
183
- expression = strings .Replace (expression , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
172
+ expression , referencedVariables , err := substituteVariableReferences (stmt .Expression , stack )
173
+ if err != nil {
174
+ // TODO: add an error return param instead of panicing
175
+ panic (err )
184
176
}
177
+
185
178
* ops = append (* ops , InterpreterOperation {
186
179
OpCode : OpCode_Return ,
187
180
PrimaryData : "SELECT " + expression + ";" ,
@@ -204,3 +197,29 @@ func OperationSizeForStatements(stmts []Statement) int32 {
204
197
}
205
198
return total
206
199
}
200
+
201
+ // substituteVariableReferences parses the specified |expression| and replaces
202
+ // any token that matches a variable name in the |stack| with "$N", where N
203
+ // indicates which variable in the returned |referenceVars| slice is used.
204
+ func substituteVariableReferences (expression string , stack * InterpreterStack ) (newExpression string , referencedVars []string , err error ) {
205
+ scanResult , err := pg_query .Scan (expression )
206
+ if err != nil {
207
+ return "" , nil , err
208
+ }
209
+
210
+ varMap := stack .ListVariables ()
211
+ for _ , token := range scanResult .Tokens {
212
+ substring := expression [token .Start :token .End ]
213
+ if _ , ok := varMap [substring ]; ok {
214
+ // TODO: There's another bug here where the same variable could
215
+ // be referenced multiple times in the expression and the
216
+ // length of the slice won't be the correct index.
217
+ referencedVars = append (referencedVars , substring )
218
+ newExpression += fmt .Sprintf ("$%d " , len (referencedVars ))
219
+ } else {
220
+ newExpression += substring + " "
221
+ }
222
+ }
223
+
224
+ return newExpression , referencedVars , nil
225
+ }
0 commit comments